[
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n  - package-ecosystem: \"pip\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n"
  },
  {
    "path": ".github/workflows/tests.yml",
    "content": "name: Tests\n\non:\n  push:\n    branches:\n      - main\n      - v*-release\n  pull_request:\n    branches:\n      - main\n\njobs:\n\n  tests:\n    name: Run tests and quality checks\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n      - name: Setup Python environment\n        uses: actions/setup-python@v5\n        with:\n          python-version: 3.10.10\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          python -m pip install \".[quality,tests]\"\n      - name: Code quality\n        run: |\n          make quality\n      - name: Run tests\n        run: |\n          make test\n\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in 
version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# PyPI configuration file\n.pypirc\n\n# Temp folders\ndata/\nwandb/\nlogs/\neval_results/\nresults/\n\n.vscode/\n.python-version"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: style quality\n\n# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)\nexport PYTHONPATH = src\n\ncheck_dirs := src tests\n\n\n# dev dependencies\ninstall:\n\tuv venv openr1 --python 3.11\n\t. openr1/bin/activate && uv pip install --upgrade pip && \\\n\tuv pip install vllm==0.8.5.post1 && \\\n\tuv pip install setuptools && \\\n\tuv pip install flash-attn --no-build-isolation && \\\n\tGIT_LFS_SKIP_SMUDGE=1 uv pip install -e \".[dev]\"\n\nstyle:\n\truff format --line-length 119 --target-version py310 $(check_dirs) setup.py\n\tisort $(check_dirs) setup.py\n\nquality:\n\truff check --line-length 119 --target-version py310 $(check_dirs) setup.py\n\tisort --check-only $(check_dirs) setup.py\n\tflake8 --max-line-length 119 $(check_dirs) setup.py\n\ntest:\n\tpytest -sv --ignore=tests/slow/ tests/\n\nslow_test:\n\tpytest -sv -vv tests/slow/\n\n# Evaluation\n\nevaluate:\n\t$(eval PARALLEL_ARGS := $(if $(PARALLEL),$(shell \\\n\t\tif [ \"$(PARALLEL)\" = \"data\" ]; then \\\n\t\t\techo \"data_parallel_size=$(NUM_GPUS)\"; \\\n\t\telif [ \"$(PARALLEL)\" = \"tensor\" ]; then \\\n\t\t\techo \"tensor_parallel_size=$(NUM_GPUS)\"; \\\n\t\tfi \\\n\t),))\n\t$(if $(filter tensor,$(PARALLEL)),export VLLM_WORKER_MULTIPROC_METHOD=spawn &&,) \\\n\tMODEL_ARGS=\"pretrained=$(MODEL),dtype=bfloat16,$(PARALLEL_ARGS),max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\" && \\\n\tif [ \"$(TASK)\" = \"lcb\" ]; then \\\n\t\tlighteval vllm $$MODEL_ARGS \"extended|lcb:codegeneration|0|0\" \\\n\t\t\t--use-chat-template \\\n\t\t\t--output-dir data/evals/$(MODEL); \\\n\telse \\\n\t\tlighteval vllm $$MODEL_ARGS \"lighteval|$(TASK)|0|0\" \\\n\t\t\t--use-chat-template \\\n\t\t\t--output-dir data/evals/$(MODEL); \\\n\tfi\n"
  },
  {
    "path": "README.md",
    "content": "# Open R1\n\n*A fully open reproduction of DeepSeek-R1. This repo is a work in progress, let's build it together!*\n\n**Table of Contents**  \n1. [Overview](#overview)  \n2. [Plan of attack](#plan-of-attack)  \n3. [Installation](#installation)  \n4. [Training models](#training-models)  \n   - [SFT](#sft)  \n   - [GRPO](#grpo)  \n5. [Evaluating models](#evaluating-models)  \n6. [Reproducing Deepseek's evaluation results](#reproducing-deepseeks-evaluation-results)  \n7. [Data generation](#data-generation)  \n   - [Generate data from a smol distilled R1 model](#generate-data-from-a-smol-distilled-r1-model)  \n   - [Generate data from DeepSeek-R1](#generate-data-from-deepseek-r1)  \n8. [Contributing](#contributing)\n\n## Overview\n\nThe goal of this repo is to build the missing pieces of the R1 pipeline such that everybody can reproduce and build on top of it. The project is simple by design and mostly consists of:\n\n\n- `src/open_r1`: contains the scripts to train models as well as generate synthetic data:\n    - `grpo.py`: trains a model with GRPO on a given dataset.\n    - `sft.py`: performs a simple SFT of a model on a dataset.\n    - `generate.py`: generates synthetic data from a model using [Distilabel](https://github.com/argilla-io/distilabel).\n- `Makefile`: contains easy-to-run commands for each step in the R1 pipeline leveraging the scripts above.\n\n### Plan of attack\n\nWe will use the DeepSeek-R1 [tech report](https://github.com/deepseek-ai/DeepSeek-R1) as a guide, which can roughly be broken down into three main steps:\n\n* Step 1: replicate the R1-Distill models by distilling a high-quality corpus from DeepSeek-R1.\n* Step 2: replicate the pure RL pipeline that DeepSeek used to create R1-Zero. 
This will likely involve curating new, large-scale datasets for math, reasoning, and code.\n* Step 3: show we can go from base model to RL-tuned via multi-stage training.\n\n<center>\n    <img src=\"assets/plan-of-attack.png\" width=\"500\">\n</center>\n\n## News 🗞️\n\n* **🧑‍🍳 [2025/05/26] (Step 1 completed!)** We release [**Mixture-of-Thoughts**](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts)--a curated reasoning dataset of 350k verified traces distilled from R1. The dataset spans tasks in mathematics, coding, and science, and is designed to teach language models to reason step-by-step. We also provide a recipe to train [OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), which replicates the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) and marks the completion of step 1 in the Open R1 project.\n* **⚡️ [2025/03/11] [(update #3)](https://huggingface.co/blog/open-r1/update-3):** We release the [**CodeForces-CoTs**](https://huggingface.co/datasets/open-r1/codeforces-cots) dataset of 10k competitive programming problems and 100k solutions distilled from R1. We also release IOI24: a new benchmark of _very_ hard problems from international olympiads. A 7B Qwen model trained on CodeForces-CoTs can outperform Claude 3.7 Sonnet on IOI24, while a 32B model can outperform R1 itself.\n* **∞ [2025/02/10] [(update #2)](https://huggingface.co/blog/open-r1/update-2):** We release the [**OpenR1-Math-220k**](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k) dataset of 220k traces distilled from R1 on a new version of NuminaMath. 
Models trained on this dataset match the performance of DeepSeek's distilled ones.\n* **🔥 [2025/02/02] [(update #1)](https://huggingface.co/blog/open-r1/update-1):** We implement the first parts of the [training](https://github.com/huggingface/open-r1?tab=readme-ov-file#training-models), [inference](https://github.com/huggingface/open-r1?tab=readme-ov-file#data-generation), and [evaluation](https://github.com/huggingface/open-r1?tab=readme-ov-file#reproducing-deepseeks-evaluation-results) pipelines. Let's go!  \n\n## Installation\n\n> [!CAUTION]\n> Libraries rely on CUDA 12.4. If you see errors related to segmentation faults, double check the version your system is running with `nvcc --version`.\n\nTo run the code in this project, first, create a Python virtual environment using e.g. `uv`.\nTo install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/getting-started/installation/).\n\n\n> [!NOTE]\n> As a shortcut, run `make install` to set up development libraries (spelled out below). Afterwards, if everything is set up correctly, you can try out the Open-R1 models.\n\n\n```shell\nuv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip\n```\n\n> [!TIP]\n> For Hugging Face cluster users, add `export UV_LINK_MODE=copy` to your `.bashrc` to suppress cache warnings from `uv`\n\nNext, install vLLM and FlashAttention:\n\n```shell\nuv pip install vllm==0.8.5.post1\nuv pip install setuptools && uv pip install flash-attn --no-build-isolation\n```\n\nThis will also install PyTorch `v2.6.0` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. 
For most contributors, we recommend:\n\n```shell\nGIT_LFS_SKIP_SMUDGE=1 uv pip install -e \".[dev]\"\n```\n\nNext, log into your Hugging Face and Weights and Biases accounts as follows:\n\n```shell\nhuggingface-cli login\nwandb login\n```\n\nFinally, check whether your system has Git LFS installed so that you can load and push models/datasets to the Hugging Face Hub:\n\n```shell\ngit-lfs --version\n```\n\nIf it isn't installed, run:\n\n```shell\nsudo apt-get install git-lfs\n```\n\n## Training models\n\n> [!NOTE]\n> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.\n\nWe support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to perform SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/Mixture-of-Thoughts](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts), run:\n\n```shell\n# Train via command line\naccelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \\\n    --dataset_name open-r1/Mixture-of-Thoughts \\\n    --dataset_config all \\\n    --eos_token '<|im_end|>' \\\n    --learning_rate 4.0e-5 \\\n    --num_train_epochs 5 \\\n    --max_seq_length 32768 \\\n    --per_device_train_batch_size 2 \\\n    --gradient_checkpointing \\\n    --bf16 \\\n    --use_liger_kernel \\\n    --output_dir data/OpenR1-Distill-7B\n\n# Train via YAML config\naccelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --config recipes/OpenR1-Distill-7B/sft/config_distill.yaml\n```\n\nCurrently, the following tasks are supported:\n\n* Supervised Fine-Tuning `sft`\n* Group Relative Policy Optimization `grpo`\n\n> [!TIP]\n> If you scale up/down the number of GPUs, we recommend also scaling up the per-device batch size or number of gradient 
accumulation steps to keep the global batch size constant.\n\nBy default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows: \n\n```shell\n# Change the base model to a smaller variant\naccelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \\\n    --model_name_or_path Qwen/Qwen3-0.6B-Base \\\n    --hub_model_id OpenR1-Distill-0.6B \\\n    --output_dir data/OpenR1-Distill-0.6B\n```\n\nIf you also wish to override the Weights and Biases default settings, you can do so as follows:\n\n```shell\naccelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \\\n    --wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO\n```\n\n**🚨 WARNING 🚨**\n\nMost base models like `meta-llama/Llama-3.2-1B` do not have a chat template, so we set ChatML as the default during training. However, for Qwen base models like `Qwen/Qwen2.5-1.5B`, a chat template is pre-defined in the tokenizer, so the EOS token must be set accordingly, e.g.\n\n```diff\n# Align EOS token with chat template for Qwen base models\naccelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --model_name_or_path Qwen/Qwen2.5-1.5B \\\n+   --eos_token '<|im_end|>' \\\n    --dataset_name open-r1/Mixture-of-Thoughts \\\n    --dataset_config all \\\n    --learning_rate 4.0e-5 \\\n    --num_train_epochs 1 \\\n    --max_seq_length 32768 \\\n    --per_device_train_batch_size 16 \\\n    --gradient_checkpointing \\\n    --bf16 \\\n    --use_liger_kernel \\\n    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill\n```\n\nIf you wish to use a custom chat template (e.g. 
Llama or Gemma), then the chat template and associated EOS token must be provided:\n\n```diff\n# Align EOS token with custom chat template\naccelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --model_name_or_path meta-llama/Llama-3.2-1B \\\n+   --chat_template \"$(cat llama_chat_template.jinja)\" \\\n+   --eos_token '<|eot_id|>' \\\n    --dataset_name open-r1/Mixture-of-Thoughts \\\n    --dataset_config all \\\n    --learning_rate 4.0e-5 \\\n    --num_train_epochs 1 \\\n    --max_seq_length 32768 \\\n    --per_device_train_batch_size 16 \\\n    --gradient_checkpointing \\\n    --bf16 \\\n    --use_liger_kernel \\\n    --output_dir data/Llama-3.2-1B-Open-R1-Distill\n```\n\n### SFT distillation\n\nWe provide a recipe to reproduce the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B), starting from the same base model. To do so, run:\n\n```shell\nACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \\\n    src/open_r1/sft.py \\\n    --config recipes/OpenR1-Distill-7B/sft/config_distill.yaml\n```\n\nThe result will be a model like [open-r1/OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), with the following downstream performance:\n\n| Model                       | AIME 2024 | MATH-500 | GPQA Diamond | LiveCodeBench v5 |\n|-----------------------------|-----------|----------|--------------|------------------|\n| OpenR1-Distill-7B           | 52.7      | 89.0     | 52.8         | 39.4             |\n| DeepSeek-R1-Distill-Qwen-7B | 51.3      | 93.5     | 52.4         | 37.4             |\n\nYou can adjust the YAML config to train on a different base model or dataset.\n\n### GRPO\n\nWe use TRL's [vLLM backend](https://huggingface.co/docs/trl/speeding_up_training?vllm+examples=GRPO#vllm-for-fast-generation-in-online-methods) to scale training to large models across multiple nodes. 
For single-node training of smol models across 8 GPUs, use `vllm_mode=\"colocate\"` to run vLLM in the same process as the training script:\n\n```shell\nACCELERATE_LOG_LEVEL=info \\\n    accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \\\n    src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml \\\n    --vllm_mode colocate\n```\n\n> [!WARNING]\n> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g.  [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).\n\nFor multi-node training on N+1 nodes, with 1 node running the vLLM server and N nodes running training, we provide an example Slurm script. For example, to run the above example on 1+1 nodes with data parallelism, run:\n\n```shell\nsbatch --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 8 --tp 1\n```\n\nSee the [Launching jobs on a Slurm cluster](#launching-jobs-on-a-slurm-cluster) section for more details.\n\n#### GRPO dataset filtering\n\nWe provide support to filter datasets by generating and computing pass rate on verifiable tasks. See this [README](scripts/pass_rate_filtering/README.md).\n\n#### 👨‍💻 Training with a code interpreter\n\nWe provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we support multiple sandbox providers:\n\n1. 
[E2B](https://e2b.dev) - Fast, cloud-based sandboxes with focus on Python execution\n2. [Morph](https://cloud.morph.so/web/) - Cloud-based sandboxes with broader language support - Python/JS/C++/Rust\n\nTo use the code reward function, first install the necessary dependencies:\n\n```shell\nuv pip install -e '.[code]'\n```\n\n##### E2B Provider\n\nTo use E2B sandboxes, create a `.env` file and add your E2B API token:\n\n```\nE2B_API_KEY=\"e2b_xxx\"\n```\n\n##### Morph Provider\n\nTo use Morph, first install the morphcloud package:\n\n```shell\npip install morphcloud\n```\n\nThen add your Morph API token to the `.env` file:\n\n```\nMORPH_API_KEY=\"YOUR_MORPH_API_KEY\"\n```\n\nTo specify which provider to use, add the `provider_type` parameter in your configuration:\n\n```yaml\n# For E2B\nprovider_type: e2b\n\n# For Morph\nprovider_type: morph\n```\n\n##### Dataset Requirements\n\nMake sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):\n\n```python\n{\n    \"language\": \"python\",  # Morph supports more languages including C++, Java, etc.\n    \"test_cases\": [\n        {\n            \"input\": \"4\\n4\\n0001\\n1000\\n0011\\n0111\\n3\\n010\\n101\\n0\\n2\\n00000\\n00001\\n4\\n01\\n001\\n0001\\n00001\\n\",\n            \"output\": \"1\\n3 \\n-1\\n0\\n\\n2\\n1 2 \\n\",\n            \"type\": \"stdin_stdout\",\n        }\n    ],\n}\n```\n\nFor example, to train a smol model on Python problems, start the vLLM server:\n\n```shell\nCUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct\n```\n\nThen run training with:\n\n```shell\nCUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \\\n    accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes=7 \\\n    src/open_r1/grpo.py --config 
recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml\n```\n\n##### Using Router Services\n\nIt is possible to be rate limited when too many scripts are executed on sandbox services. For both providers, we offer router scripts that can be launched on a CPU node:\n\nFor E2B:\n```shell\nsbatch slurm/e2b_router.slurm\n```\n\nFor Morph:\n```shell\nsbatch slurm/morph_router.slurm\n```\n\nThen add the router URL in your training YAML config:\n```yaml\n# For E2B\ne2b_router_url: 1.2.3.4:8000\n\n# For Morph\nmorph_router_url: 1.2.3.4:8000\n```\n\nThe port should match the one used when launching the router.\nAll training jobs can share the same router IP which will ensure parallel executions are properly managed.\n\n#### Competitive Programming problems: IOI & CodeForces\n\nWe provide `ioi_code_reward` and `cf_code_reward` reward functions for executing problems from [IOI](https://hf.co/datasets/open-r1/ioi) and [CodeForces](https://huggingface.co/datasets/open-r1/codeforces), respectively. You can use either [piston](https://github.com/engineer-man/piston) or Morph (currently IOI only) as your execution provider.\n\n##### Piston \n\nTo use Piston:\n1. Get piston workers running, see [slurm/piston/README.md](./slurm/piston/README.md)\n2. Set your environment variable `PISTON_ENDPOINTS` to `slurm` or to a list of piston worker endpoints\n\nFor IOI:\n\n3. In your configuration, use `ioi_provider: \"piston\"`\n\nFor CodeForces:\n\n3. Download the generated (hard) test cases:\n```\n# change PATH_TO_SAVE_TESTCASES. Increase --max-workers according to your machine's capacity\nhuggingface-cli download open-r1/codeforces --repo-type=dataset --include='generated_tests/*.parquet' --max-workers=8 --local-dir PATH_TO_SAVE_TESTCASES \n```\n4. Save the path in .env:\n```\nCF_TESTS_FOLDER=PATH_TO_SAVE_TESTCASES\n```\n\n##### Morph \n\nMorph is a cloud-based solution that provides sandboxed environments for running code. To use it:\n1. 
Install the Morph client: `pip install morphcloud`\n2. Add your Morph API key to the `.env` file: `MORPH_API_KEY=\"your_key_here\"`\n3. In your configuration, use `ioi_provider: \"morph\"`\n\n##### Example recipes\nFor IOI:\n\nSee the [example recipe](./recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml) for how to use the IOI reward function:\n\n```shell\nACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \\\n    --num_processes=7 src/open_r1/grpo.py \\\n    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml\n```\n\nFor CodeForces:\n\n```shell\nsbatch --job-name=cf-grpo --nodes=2 slurm/train.slurm --model Qwen2.5-Coder-7B-Instruct --task grpo --config codeforces --accelerator zero3 --dp 8 --tp 1\n```\n\n### Launching jobs on a Slurm cluster\n\nIf you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:\n\n```shell\nsbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model {model_name} --task {task} --config {config_suffix} --accelerator {accelerator}\n```\n\nHere `{model_name}` and `{task}` are defined as above, while `{config_suffix}` refers to the specific config and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'`. Here's a concrete example to run SFT on 1 node of 8 GPUs:\n\n```shell\nsbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3\n```\n\nYou can scale the number of nodes by increasing the `--nodes` flag.\n\nFor GRPO, we use 1 node for the vLLM server and N nodes for training. 
For example, to run GRPO on 1+1 nodes with mixed data and tensor parallelism, run:\n\n```shell\nsbatch --job-name=open_r1 --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 4 --tp 2\n```\n\n> [!NOTE]\n> The configuration in `slurm/train.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.\n\n### Customising the dataset mixture\n\nTo combine multiple datasets as a single training mixture, you can specify the `dataset_mixture` parameter in the YAML config file. Here's a template for how to do this:\n\n```yaml\ndataset_mixture:\n  datasets:                     # List of datasets to include in the mixture\n    - id: dataset_1             # Hub dataset ID\n      config: config_name_1     # Name of the dataset config\n      split: split_1            # Split to use from the dataset\n      columns:                  # Columns to keep\n        - column_1              \n        - column_2    \n      weight: 0.25              # Fraction of dataset to use\n    - id: dataset_2\n      config: config_name_2\n      split: split_2\n      columns:                  \n        - column_1              \n        - column_2   \n      weight: 0.5\n  seed: 42                      # Seed for shuffling the combined dataset\n  test_split_size: 0.1          # Fraction of mixture to use for a test split\n```\n\n## Evaluating models\n\nWe use `lighteval` to evaluate models. 
For models which fit on a single GPU, run:\n\n```shell\nexport VLLM_WORKER_MULTIPROC_METHOD=spawn # Required for vLLM\nMODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nOUTPUT_DIR=data/evals/$MODEL\n\n# AIME 2024\nTASK=aime24\nlighteval vllm $MODEL_ARGS \"lighteval|$TASK|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n\n# MATH-500\nTASK=math_500\nlighteval vllm $MODEL_ARGS \"lighteval|$TASK|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n\n# GPQA Diamond\nTASK=gpqa:diamond\nlighteval vllm $MODEL_ARGS \"lighteval|$TASK|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n\n# LiveCodeBench\nlighteval vllm $MODEL_ARGS \"extended|lcb:codegeneration|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR \n```\n\nTo increase throughput across multiple GPUs, use _data parallel_ as follows:\n\n```shell\nNUM_GPUS=8\nMODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nTASK=aime24\nOUTPUT_DIR=data/evals/$MODEL\n\nlighteval vllm $MODEL_ARGS \"lighteval|$TASK|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n```\n\nFor large models which require sharding across GPUs, use _tensor parallel_ and run:\n\n```shell\nNUM_GPUS=8\nMODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nTASK=aime24\nOUTPUT_DIR=data/evals/$MODEL\n\nexport VLLM_WORKER_MULTIPROC_METHOD=spawn\nlighteval vllm $MODEL_ARGS \"lighteval|$TASK|0|0\" \\\n    --use-chat-template 
\\\n    --output-dir $OUTPUT_DIR\n```\n\nYou can also launch an evaluation with `make evaluate`, specifying the model, task, and optionally the parallelism technique and number of GPUs.\n\nTo evaluate on a single GPU:\n\n```shell\nmake evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24\n```\n\nTo use Data Parallelism:\n\n```shell\nmake evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8\n```\n\nTo use Tensor Parallelism:\n\n```shell\nmake evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8\n```\n\n## Reproducing DeepSeek's evaluation results\n\nThe DeepSeek-R1 paper uses sampling with 4-64 responses per query to estimate `pass@1` accuracy, but does not state the exact number of responses used for each benchmark. In the tables below, we estimate `pass@1` accuracy with the following number of responses per query:\n\n|   Benchmark   | Number of responses per query |\n|:-------------:|:-----------------------------:|\n|   AIME 2024   |              64               |\n|   MATH-500    |               4               |\n| GPQA Diamond  |               8               |\n| LiveCodeBench |              16               |\n\n\nNote that for benchmarks like AIME24, it is important to sample many responses as there are only 30 problems and this can introduce high variance across repeated runs. 
The choice of how many responses to sample per prompt likely explains the small differences between our evaluation results and those reported by DeepSeek.\n\n### AIME 2024\n\nWe are able to reproduce Deepseek's reported results on the AIME 2024 benchmark within ~1-3 standard deviations:\n\n| Model                         | AIME 2024 (🤗 LightEval) | AIME 2024 (DeepSeek Reported) |\n|:------------------------------|:------------------------:|:-----------------------------:|\n| DeepSeek-R1-Distill-Qwen-1.5B |           30.7           |             28.9              |\n| DeepSeek-R1-Distill-Qwen-7B   |           50.8           |             55.5              |\n| DeepSeek-R1-Distill-Qwen-14B  |           65.9           |             69.7              |\n| DeepSeek-R1-Distill-Qwen-32B  |           69.7           |             72.6              |\n| DeepSeek-R1-Distill-Llama-8B  |           43.9           |             41.7              |\n| DeepSeek-R1-Distill-Llama-70B |           63.0           |             70.0              |\n\nTo reproduce these results use the following command:\n\n```shell\nNUM_GPUS=1 # Set to 8 for 32B and 70B models\nMODEL=deepseek-ai/{model_name}\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nOUTPUT_DIR=data/evals/$MODEL\n\nlighteval vllm $MODEL_ARGS \"lighteval|aime24|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n```\n\nAlternatively, you can launch Slurm jobs as follows:\n\n```shell\npython scripts/run_benchmarks.py --model-id {model_id}  --benchmarks aime24\n```\n\n### MATH-500\n\nWe are able to reproduce Deepseek's reported results on the MATH-500 benchmark within ~1-3 standard deviations:\n\n| Model                         | MATH-500 (🤗 LightEval) | MATH-500 (DeepSeek Reported) 
|\n|:------------------------------|:-----------------------:|:----------------------------:|\n| DeepSeek-R1-Distill-Qwen-1.5B |          83.1           |             83.9             |\n| DeepSeek-R1-Distill-Qwen-7B   |          94.5           |             92.8             |\n| DeepSeek-R1-Distill-Qwen-14B  |          94.1           |             93.9             |\n| DeepSeek-R1-Distill-Qwen-32B  |          95.6           |             94.3             |\n| DeepSeek-R1-Distill-Llama-8B  |          88.6           |             89.1             |\n| DeepSeek-R1-Distill-Llama-70B |          95.1           |             94.5             |\n\nTo reproduce these results use the following command:\n\n```shell\nexport VLLM_WORKER_MULTIPROC_METHOD=spawn\nNUM_GPUS=1 # Set to 8 for 32B and 70B models\nMODEL=deepseek-ai/{model_name}\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nOUTPUT_DIR=data/evals/$MODEL\n\nlighteval vllm $MODEL_ARGS \"lighteval|math_500|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n```\n\nAlternatively, you can launch Slurm jobs as follows:\n\n```shell\npython scripts/run_benchmarks.py --model-id {model_id}  --benchmarks math_500\n```\n\n### GPQA Diamond\n\nWe are able to reproduce Deepseek's reported results on the GPQA Diamond benchmark within ~1-3 standard deviations:\n\n| Model                         | GPQA Diamond (🤗 LightEval) | GPQA Diamond (DeepSeek Reported) |\n|:------------------------------|:---------------------------:|:--------------------------------:|\n| DeepSeek-R1-Distill-Qwen-1.5B |            35.8             |               33.8               |\n| DeepSeek-R1-Distill-Qwen-7B   |            50.5             |               49.1               |\n| DeepSeek-R1-Distill-Qwen-14B  |            61.5             |               59.1               |\n| 
DeepSeek-R1-Distill-Qwen-32B  |            63.1             |               62.1               |\n| DeepSeek-R1-Distill-Llama-8B  |            46.7             |               49.0               |\n| DeepSeek-R1-Distill-Llama-70B |            67.4             |               65.2               |\n\nTo reproduce these results use the following command:\n\n```shell\nexport VLLM_WORKER_MULTIPROC_METHOD=spawn\nNUM_GPUS=1 # Set to 8 for 32B and 70B models\nMODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nOUTPUT_DIR=data/evals/$MODEL\n\nlighteval vllm $MODEL_ARGS \"lighteval|gpqa:diamond|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n```\n\n```shell\npython scripts/run_benchmarks.py --model-id {model_id}  --benchmarks gpqa\n```\n\n### LiveCodeBench\n\nWe are able to reproduce Deepseek's reported results on the LiveCodeBench code generation benchmark within ~1-3 standard deviations:\n\n| Model                         | LiveCodeBench (🤗 LightEval) | LiveCodeBench (DeepSeek Reported) |\n|:------------------------------|:----------------------------:|:---------------------------------:|\n| DeepSeek-R1-Distill-Qwen-1.5B |             16.1             |               16.9                |\n| DeepSeek-R1-Distill-Qwen-7B   |             37.4             |               37.6                |\n| DeepSeek-R1-Distill-Qwen-14B  |             51.3             |               53.1                |\n| DeepSeek-R1-Distill-Qwen-32B  |             56.0             |               57.2                |\n| DeepSeek-R1-Distill-Llama-8B  |             37.4             |               39.6                |\n| DeepSeek-R1-Distill-Llama-70B |             55.9             |               57.5                |\n\nTo reproduce these results use the following command:\n\n```shell\nNUM_GPUS=1 # Set to 8 for 32B 
and 70B models, or data_parallel_size=8 with the smaller models for speed\nMODEL=deepseek-ai/{model_name}\nMODEL_ARGS=\"model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nOUTPUT_DIR=data/evals/$MODEL\n\nlighteval vllm $MODEL_ARGS \"extended|lcb:codegeneration|0|0\" \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR\n```\n\n```shell\npython scripts/run_benchmarks.py --model-id {model_id}  --benchmarks lcb\n```\n\n## Data generation\n\n### Generate data from a smol distilled R1 model\n\nThe following example can be run in 1xH100. \nFirst install the following dependencies:\n\n```shell\nuv pip install \"distilabel[vllm]>=1.5.2\"\n```\n\nNow save the following snippet into a file named `pipeline.py` and run it with `python pipeline.py`. It will generate 4 outputs for each of the 10 examples (change the username for the repository to your org/user name):\n\n```python\nfrom datasets import load_dataset\nfrom distilabel.models import vLLM\nfrom distilabel.pipeline import Pipeline\nfrom distilabel.steps.tasks import TextGeneration\n\n\nprompt_template = \"\"\"\\\nYou will be given a problem. 
Please reason step by step, and put your final answer within \\\\boxed{}:\n{{ instruction }}\"\"\"\n\ndataset = load_dataset(\"AI-MO/NuminaMath-TIR\", split=\"train\").select(range(10))\n\nmodel_id = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\"  # Exchange with another smol distilled r1\n\nwith Pipeline(\n    name=\"distill-qwen-7b-r1\",\n    description=\"A pipeline to generate data from a distilled r1 model\",\n) as pipeline:\n\n    llm = vLLM(\n        model=model_id,\n        tokenizer=model_id,\n        extra_kwargs={\n            \"tensor_parallel_size\": 1,\n            \"max_model_len\": 8192,\n        },\n        generation_kwargs={\n            \"temperature\": 0.6,\n            \"max_new_tokens\": 8192,\n        },\n    )\n    prompt_column = \"problem\"\n    text_generation = TextGeneration(\n        llm=llm, \n        template=prompt_template,\n        num_generations=4,\n        input_mappings={\"instruction\": prompt_column} if prompt_column is not None else {}\n    )\n\n\nif __name__ == \"__main__\":\n    distiset = pipeline.run(dataset=dataset)\n    distiset.push_to_hub(repo_id=\"username/numina-deepseek-r1-qwen-7b\")\n```\n\nTake a look at the sample dataset at [HuggingFaceH4/numina-deepseek-r1-qwen-7b](https://huggingface.co/datasets/HuggingFaceH4/numina-deepseek-r1-qwen-7b).\n\n\n### Generate data from DeepSeek-R1\n\nTo run the bigger DeepSeek-R1, we used 2 nodes, each with 8×H100 GPUs using the slurm file present in this repo at `slurm/generate.slurm`. 
First, install the dependencies:\n\n(for now we need to install the vllm dev wheel that [fixes the R1 cuda graph capture](https://github.com/vllm-project/vllm/commits/221d388cc5a836fa189305785ed7e887cea8b510/csrc/moe/moe_align_sum_kernels.cu))\n```shell\npip install https://wheels.vllm.ai/221d388cc5a836fa189305785ed7e887cea8b510/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu121\n\nuv pip install \"distilabel[vllm,ray,openai]>=1.5.2\"\n```\n\nAnd then run the following command:\n\n```shell\nsbatch slurm/generate.slurm \\\n    --hf-dataset AI-MO/NuminaMath-TIR \\\n    --temperature 0.6 \\\n    --prompt-column problem \\\n    --model deepseek-ai/DeepSeek-R1 \\\n    --hf-output-dataset username/r1-dataset\n```\n\n> [!NOTE]  \n> While the job is running, you can set up an SSH tunnel through the cluster login node to access the Ray dashboard from your computer by running `ssh -L 8265:ray_ip_head_node:8265 <login_node>` and then browsing to `http://localhost:8265`.\n\n\n### Data decontamination\n\nFollowing [s1: Simple test-time scaling](https://huggingface.co/papers/2501.19393), the data can be decontaminated using the script at: [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:\n\n```shell\npython scripts/decontaminate.py \\\n    --dataset \"open-r1/verifiable-coding-problems-python\" \\\n    --problem_column problem \\\n    --cleanup\n```\n\nIt will decontaminate against the benchmark datasets, and remove the contaminated samples afterwards. If no argument `--new_dataset_name` is provided, the same dataset name will be reused with a `_decontaminated` suffix. 
It runs against the prompt, which for this dataset is the column `problem`, but a different one can be provided.\n\nArguments for the script:\n\n```shell\nusage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]\n\noptions:\n  -h, --help            show this help message and exit\n  --dataset DATASET     Name of the dataset to check for contamination.\n  --split SPLIT         Split to check for contamination, defaults to `train`.\n  --ngram_size NGRAM_SIZE\n                        Size of n-grams to build, defaults to 8.\n  --problem_column PROBLEM_COLUMN\n                        Name of the column containing the problem (prompt).\n  --cleanup           Whether to remove the contaminated rows before pushing the dataset.\n  --new_dataset_name NEW_DATASET_NAME\n                        New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.\n```\n\n## Contributing\n\nContributions are welcome. Please refer to https://github.com/huggingface/open-r1/issues/23.\n\n## Acknowledgements\n\nThis project is built with the collective efforts of many groups and individuals in the open AI community. We are especially grateful to the vLLM and SGLang teams for creating high-performance tooling to scale the rollouts of GRPO. We also thank the teams at [OpenThoughts](https://www.open-thoughts.ai), [Prime Intellect](https://www.primeintellect.ai), and [General Reasoning](https://gr.inc) for creating and sharing high-quality datasets for reasoning.\n\n## Citation\n\nIf you find this project useful in your own work, please consider citing as follows:\n\n```\n@misc{openr1,\n    title = {Open R1: A fully open reproduction of DeepSeek-R1},\n    url = {https://github.com/huggingface/open-r1},\n    author = {{Hugging Face}},\n    month = {January},\n    year = {2025}\n}\n```\n"
  },
  {
    "path": "recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml",
    "content": "# Model arguments\nmodel_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\n# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works\nchat_template: \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\\\n' + '```json' + '\\\\n' + tool['function']['arguments'] + '\\\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\\\n' + '```json' + '\\\\n' + tool['function']['arguments'] + '\\\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 
'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}\"\ndataset_name: open-r1/OpenR1-Math-220k\ndataset_prompt_column: problem\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# GRPO trainer config\nbf16: true\nuse_vllm: true\ndo_eval: false\ngradient_accumulation_steps: 4\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: DeepSeek-R1-Distill-Qwen-1.5B-GRPO\nhub_strategy: every_save\nlearning_rate: 1.0e-06\nlog_completions: true\nlog_level: info\nlogging_first_step: true\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\nmax_prompt_length: 512\nmax_completion_length: 2048\nmax_steps: -1\nnum_generations: 16\nnum_train_epochs: 1\noutput_dir: data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO\noverwrite_output_dir: true\nper_device_eval_batch_size: 16\nper_device_train_batch_size: 16\npush_to_hub: true\nreport_to:\n- wandb\nreward_funcs:\n- accuracy\n- format\n- tag_count\nreward_weights:\n- 1.0\n- 1.0\n- 1.0\nsave_strategy: \"epoch\"\nsave_total_limit: 1\nseed: 42\ntemperature: 0.7\nuse_liger_kernel: true\nwarmup_ratio: 0.1\n"
  },
  {
    "path": "recipes/OlympicCoder-32B/sft/config_v00.00.yaml",
    "content": "# Config for 16 nodes of 8 H100s with FSDP1\n# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-Coder-32B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\ndataset_name: open-r1/codeforces-cots\ndataset_config: solutions_decontaminated\ndataset_num_proc: 12\n\n# SFT trainer config\nbf16: true\ndo_eval: false\neval_strategy: 'no'\ngradient_accumulation_steps: 1\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_always_push: true\nhub_model_id: OlympicCoder-32B\nhub_strategy: every_save\nlearning_rate: 4.0e-05\nlog_level: info\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\npacking: false\nmax_grad_norm: 0.2\nmax_length: 22528 # we were unable to train at 32k due to OOM. See https://github.com/huggingface/transformers/issues/35983 for context parallelism support.\nmax_steps: -1\nnum_train_epochs: 10\noptim: paged_adamw_8bit\noutput_dir: data/OlympicCoder-32B\noverwrite_output_dir: true\nper_device_eval_batch_size: 1\nper_device_train_batch_size: 1\npush_to_hub: true\nreport_to:\n- wandb\nsave_only_model: true # needed to bypass FSDP errors with saving paged optimizers\nsave_strategy: epoch\nsave_total_limit: 1\nseed: 42\nuse_liger_kernel: false # fails on multi-node\nwarmup_ratio: 0.03"
  },
  {
    "path": "recipes/OlympicCoder-7B/sft/config_v00.00.yaml",
    "content": "# Config for 1 node of 8 H100s with DeepSpeed ZeRO-3\n# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\ndataset_name: open-r1/codeforces-cots\ndataset_config: solutions_decontaminated\ndataset_num_proc: 48\n\n# SFT trainer config\nbf16: true\ndo_eval: false\neval_strategy: 'no'\ngradient_accumulation_steps: 8\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: open-r1/OlympicCoder-7B\nhub_strategy: every_save\nlearning_rate: 1.0e-05\nlog_level: info\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\npacking: false\nmax_grad_norm: 0.2\nmax_length: 32768\nmax_steps: -1\nnum_train_epochs: 10\noutput_dir: data/OlympicCoder-7B\noverwrite_output_dir: true\nper_device_eval_batch_size: 1\nper_device_train_batch_size: 2\npush_to_hub: true\nreport_to:\n- wandb\nsave_strategy: epoch\nsave_total_limit: 1\nseed: 42\nuse_liger_kernel: true\nwarmup_ratio: 0.03"
  },
  {
    "path": "recipes/OpenR1-Distill-7B/sft/config_distill.yaml",
    "content": "# Config for 1 node of 8 x H100s (80GB)\n# Model arguments\nmodel_name_or_path: open-r1/Qwen2.5-Math-7B-RoPE-300k\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\nchat_template: \"{%- if tools %}\\n    {{- '<|im_start|>system\\\\n' }}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- messages[0]['content'] }}\\n    {%- else %}\\n        {{- 'You are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines.' 
}}\\n    {%- endif %}\\n    {{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}\\n    {%- for tool in tools %}\\n        {{- \\\"\\\\n\\\" }}\\n        {{- tool | tojson }}\\n    {%- endfor %}\\n    {{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}\\n{%- else %}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n    {%- else %}\\n        {{- '<|im_start|>system\\\\nYou are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. 
The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines.<|im_end|>\\\\n' }}\\n    {%- endif %}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\"\ndataset_name: open-r1/Mixture-of-Thoughts\ndataset_config: all\ndataset_num_proc: 12\neos_token: <|im_end|>\n\n# SFT trainer config\nbf16: true\ndo_eval: 
false\neval_strategy: 'no'\ngradient_accumulation_steps: 8\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: OpenR1-Distill-7B\nhub_strategy: every_save\nlearning_rate: 4.0e-05\nlog_level: info\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\npacking: false\nmax_grad_norm: 0.2\nmax_length: 32768\nmax_steps: -1\nnum_train_epochs: 5\noutput_dir: data/OpenR1-Distill-7B\noverwrite_output_dir: true\nper_device_eval_batch_size: 1\nper_device_train_batch_size: 2\npush_to_hub: true\nreport_to:\n- wandb\nsave_strategy: epoch\nsave_total_limit: 1\nseed: 42\nuse_liger_kernel: true\nwarmup_ratio: 0.03"
  },
  {
    "path": "recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml",
    "content": "# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-1.5B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\ndataset_name: open-r1/OpenR1-Math-220k\ndataset_prompt_column: problem\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# GRPO trainer config\nbf16: true\nuse_vllm: true\ndo_eval: false\ngradient_accumulation_steps: 4\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: Qwen2.5-1.5B-Open-R1-GRPO\nhub_strategy: every_save\nlearning_rate: 2.0e-05\nlog_completions: true\nlog_level: info\nlogging_first_step: true\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine\nmax_prompt_length: 512\nmax_completion_length: 1024\nmax_steps: -1\nnum_generations: 16\nnum_train_epochs: 1\noutput_dir: data/Qwen2.5-1.5B-Open-R1-GRPO\noverwrite_output_dir: true\nper_device_eval_batch_size: 16\nper_device_train_batch_size: 16\npush_to_hub: true\nreport_to:\n- wandb\nreward_funcs:\n- accuracy\n- format\n- tag_count\nreward_weights:\n- 1.0\n- 1.0\n- 1.0\nsave_strategy: \"epoch\"\nsave_total_limit: 1\nseed: 42\nwarmup_ratio: 0.1\n"
  },
  {
    "path": "recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml",
    "content": "# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-1.5B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\ndataset_name: open-r1/verifiable-coding-problems-python\ndataset_prompt_column: problem_statement\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# GRPO trainer config\nbeta: 0.01\nbf16: true\nuse_vllm: true\ndo_eval: false\ngradient_accumulation_steps: 4\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO\nhub_strategy: every_save\nlearning_rate: 5.0e-06\nlog_completions: true\nlog_level: info\nlogging_first_step: true\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\nmax_prompt_length: 1024\nmax_completion_length: 2048\nmax_steps: 500\nnum_generations: 14\nnum_train_epochs: 1\noutput_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO\noverwrite_output_dir: true\nper_device_train_batch_size: 16\npush_to_hub: true\nreport_to:\n- wandb\nreward_funcs:\n- code\n- format\nreward_weights:\n- 1.0\n- 0.1\nsave_strategy: \"steps\"\nsave_steps: 50\nsave_total_limit: 1\nseed: 42\ntemperature: 1.0\nwarmup_ratio: 0.03"
  },
  {
    "path": "recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml",
    "content": "# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-1.5B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\ndataset_name: open-r1/ioi\ndataset_prompt_column: problem\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# GRPO trainer config\nbeta: 0.01\nbf16: true\nuse_vllm: true\ndo_eval: false\ngradient_accumulation_steps: 4\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO\nhub_strategy: every_save\nlearning_rate: 5.0e-06\nlog_completions: true\nlog_level: info\nlogging_first_step: true\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: cosine_with_min_lr\nlr_scheduler_kwargs:\n  min_lr_rate: 0.1\nmax_prompt_length: 1024\nmax_completion_length: 2048\nmax_steps: 500\nnum_generations: 14\nnum_train_epochs: 1\noutput_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO\noverwrite_output_dir: true\nper_device_train_batch_size: 16\npush_to_hub: true\nreport_to:\n- wandb\nsave_strategy: \"steps\"\nsave_steps: 50\nsave_total_limit: 1\nseed: 42\ntemperature: 1.0\nwarmup_ratio: 0.03\n# ioi specific config\ncode_language: cpp\nreward_funcs:\n- ioi_code\n- code_format\n- format\nreward_weights:\n- 1.0\n- 0.1\n- 0.1\n# for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating\n# otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions\ncode_eval_test_batch_size: 3"
  },
  {
    "path": "recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_codeforces.yaml",
    "content": "# Model arguments\nmodel_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n# Data training arguments\ndataset_name: open-r1/codeforces\ndataset_prompt_column: prompt\ndataset_config: verifiable-prompts\ndataset_test_split: test\ndataset_train_split: train\n\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# GRPO trainer config\ncallbacks:\n- push_to_hub_revision\nbenchmarks:\n- lcb_v4\nbeta: 0.0\nloss_type: dr_grpo\nscale_rewards: false\nbf16: true\ndo_eval: false\neval_strategy: \"no\"\nuse_vllm: true\nvllm_device: auto\nvllm_gpu_memory_utilization: 0.7\ngradient_accumulation_steps: 32\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nhub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO\nhub_model_revision: v01.00\nhub_strategy: every_save\nlearning_rate: 1.0e-06\nlog_completions: true\nlog_level: info\nlogging_first_step: true\nlogging_steps: 1\nlogging_strategy: steps\nlr_scheduler_type: constant_with_warmup\nmax_grad_norm: 0.2\nmax_prompt_length: 2000\nmax_completion_length: 8192\nmax_steps: -1\nnum_generations: 16\n# aiming for 1k optimization steps\n# total_samples_per_batch = num_gpus * grad_accumulation_steps * per_device_batch_size = 8 * 32 * 4 = 1024\n# unique_prompts_per_batch = total_samples_per_batch / num_generations = 1024 / 16 = 64\n# #dataset ~= 16k (8k * 2, for python and cpp)\n# global_steps_per_epoch = #dataset / unique_prompts_per_batch = 16k / 64 ~= 250\n# epochs_for_1k_steps = 1000/250 = 4 epochs\nnum_train_epochs: 4\noutput_dir: data/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO_v01.00\noverwrite_output_dir: true\nper_device_train_batch_size: 
4\npush_to_hub: true\nreport_to:\n- wandb\nreward_funcs:\n- cf_code\n- code_format\nreward_weights:\n- 1.0\n- 0.1\nsave_strategy: \"steps\"\nsave_steps: 0.05\nsave_total_limit: 1\nseed: 42\ntemperature: 0.7\nwandb_entity: huggingface\nwandb_project: open-r1\nwarmup_ratio: 0.1\n\nmask_truncated_completions: true\n# for each generation, evaluate this many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating\n# otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions\ncode_eval_test_batch_size: -1\ncode_eval_scoring_mode: weighted_sum"
  },
  {
    "path": "recipes/README.md",
    "content": "# Post-training recipes\n\n## OpenR1 Distill 7B\n\nTo train the OpenR1 Distill 7B model, run:\n\n```\nsbatch --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3\n```\n\n## OlympicCoder\n\nTo train the OlympicCoder models, run:\n\n```\n# 7B\nsbatch --nodes=1 slurm/train.slurm --model OlympicCoder-7B --task sft --config v00.00 --accelerator zero3\n\n# 32B\nsbatch --nodes=16 slurm/train.slurm --model OlympicCoder-32B --task sft --config v00.00 --accelerator fsdp\n```\n\nNote that we found it necessary to switch to FSDP1 and paged AdamW 8-bit for the 32B model in order to fit the largest possible context size."
  },
  {
    "path": "recipes/accelerate_configs/ddp.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\ngpu_ids: all\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "recipes/accelerate_configs/fsdp.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfsdp_config:\n  fsdp_activation_checkpointing: false # Need fix from: https://github.com/huggingface/transformers/pull/36610\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_backward_prefetch: BACKWARD_PRE\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_forward_prefetch: true\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sync_module_states: true\n  fsdp_use_orig_params: true\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false"
  },
  {
    "path": "recipes/accelerate_configs/zero2.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  deepspeed_multinode_launcher: standard\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false"
  },
  {
    "path": "recipes/accelerate_configs/zero3.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  deepspeed_multinode_launcher: standard\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "recipes/dataset_filtering/config_demo.yaml",
    "content": "# Model arguments\nmodel_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\nmodel_revision: main\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\n# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works\nchat_template: \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\\\n' + '```json' + '\\\\n' + tool['function']['arguments'] + '\\\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\\\n' + '```json' + '\\\\n' + tool['function']['arguments'] + '\\\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 
'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}\"\ndataset_name: open-r1/OpenR1-Math-220k\ndataset_prompt_column: problem\nsystem_prompt: \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n\n# Generation arguments\nmax_completion_length: 2048\nnum_generations: 8\ntemperature: 0.7\ntop_p: 0.95\n\n# Reward func arguments\nreward_funcs:\n- accuracy\nreward_weights:\n- 1.0\n\n# Filtering arguments. Samples with a pass rate outside the interval `pass_rate_min < x < pass_rate_max` will be filtered.  \npass_rate_min: 0.2\npass_rate_max: 0.8\n"
  },
  {
    "path": "recipes/dataset_filtering/filter_dapo.yaml",
    "content": "# Model arguments\nmodel_name_or_path: open-r1/R1-Distill-Qwen-Math-7B\nmodel_revision: v03.00-step-000008190\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\n# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works\ndataset_name: open-r1/DAPO-Math-17k-Processed\ndataset_config: all\ndataset_split: train\n\n# Generation arguments\nmax_completion_length: 32000\nnum_generations: 8\ntemperature: 1.0\n\n# Reward func arguments\nreward_funcs:\n- accuracy\nreward_weights:\n- 1.0\n\n# Filtering arguments. Samples with mean reward outside of low / high will be filtered\npass_rate_min: 0.1\npass_rate_max: 0.6\n\noutput_dataset_name: open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-v03.00-step-000008190-filter\n"
  },
  {
    "path": "recipes/dataset_filtering/filter_python.yaml",
    "content": "# Model arguments\nmodel_name_or_path: open-r1/R1-Distill-Qwen-Math-7B-Merges\nmodel_revision: v00.00-step-000003660_v01.00-step-000002600_weights-0.50-0.50\ntorch_dtype: bfloat16\nattn_implementation: flash_attention_2\n\n# Data training arguments\n# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works\ndataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\ndataset_prompt_column: problem\n\n# Generation arguments\nmax_completion_length: 16000\nnum_generations: 8\ntemperature: 0.7\n\n# Reward func arguments\nreward_funcs:\n- binary_code\nreward_weights:\n- 1.0\ne2b_router_url: ip-10-53-85-92:8000\n\n# Filtering arguments. Samples with mean reward outside of low / high will be filtered\npass_rate_min: 0.1\npass_rate_max: 0.6\n"
  },
  {
    "path": "scripts/benchmark_e2b.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nBenchmark script for the code_reward function with E2B.\n\nThis script measures the performance of the code_reward function with varying numbers\nof samples and parallelization levels.\n\nEach sample is a CodeForces problem with a gold standard solution that is executed against a set of public test cases.\n\"\"\"\n\nfrom datasets import load_dataset\nimport time\nfrom tqdm.auto import tqdm\n\nfrom dotenv import load_dotenv\nload_dotenv()\n\nfrom open_r1.rewards import code_reward\n\ndef benchmark_code_reward(example):\n    start_time = time.time()\n    test_completions = [[{\"content\": example[\"gold_standard_solution\"]}]]\n    reward_kwargs = {\"verification_info\": [example[\"verification_info\"]]}\n    rewards = code_reward(test_completions, **reward_kwargs)\n    end_time = time.time()\n    example[\"test_reward\"] = rewards[0]\n    example[\"reward_time\"] = end_time - start_time\n    return example\n\nif __name__ == \"__main__\":\n    parallel_dict = {\n        16:[1,4,16],\n        64:[4,16, 64],\n        256:[16, 64, 96], # cap at 96 as PRO account is limited to 100\n    }\n    # Store results for table formatting\n    results = []\n    \n    for num_samples in tqdm([16, 64,256], desc=\"Benchmarking samples\"):\n        for num_parallel in parallel_dict[num_samples]:\n            code_dataset = 
load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated\")\n            code_dataset = code_dataset[\"train\"].shuffle(seed=42).select(range(num_samples))\n\n            test_completions = [[{\"content\": example[\"gold_standard_solution\"]}] for example in code_dataset]\n            reward_kwargs = {\"verification_info\": [example[\"verification_info\"] for example in code_dataset]}\n\n            start_time = time.time()\n            rewards = code_reward(test_completions, num_parallel=num_parallel, **reward_kwargs)\n            execution_time = time.time() - start_time\n            \n            # Calculate some statistics about rewards\n            mean_reward = sum(rewards) / len(rewards)\n            min_reward = min(rewards)\n            max_reward = max(rewards)\n            \n            # Store results\n            results.append({\n                \"num_samples\": num_samples,\n                \"num_parallel\": num_parallel,\n                \"execution_time\": execution_time,\n                \"mean_reward\": mean_reward,\n                \"min_reward\": min_reward,\n                \"max_reward\": max_reward\n            })\n    \n    print(\"\\n## Benchmark Results\\n\")\n    print(\"| Sample Size | Parallelization | Execution Time (s) | Mean Reward | Min Reward | Max Reward |\")\n    print(\"|:-----------:|:---------------:|------------------:|:-----------:|:-----------:|:-----------:|\")\n    \n    for result in results:\n        print(f\"| {result['num_samples']:^11} | {result['num_parallel']:^15} | {result['execution_time']:17.2f} | {result['mean_reward']:^11.4f} | {result['min_reward']:^11.4f} | {result['max_reward']:^11.4f} |\")\n    \n"
  },
  {
    "path": "scripts/decontaminate.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.\nIt uses the same approach presented in https://huggingface.co/papers/2501.19393,\nas found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py\n\nUsage:\n\npython scripts/decontaminate.py \\\n    --dataset open-r1/verifiable-coding-problems-python \\\n    --split train \\\n    --ngram_size 8 \\\n    --problem_column problem \\\n    --cleanup\n\"\"\"\n\nimport collections\n\nfrom tqdm import tqdm\n\n\ndef normalize_string(text: str) -> str:\n    \"\"\"Basic string normalization.\"\"\"\n    # Convert to lowercase and normalize whitespace\n    text = text.lower().strip()\n    # Replace multiple spaces with single space\n    text = \" \".join(text.split())\n    return text\n\n\ndef word_ngrams(text: str, n: int) -> list:\n    \"\"\"Generate word-level n-grams from text.\"\"\"\n    words = text.split()\n    return [\" \".join(words[i : i + n]) for i in range(len(words) - n + 1)]\n\n\ndef build_ngram_lookup(documents: list[str], ngram_size: int = 8) -> dict[str, set[int]]:\n    \"\"\"Build ngram lookup for documents.\"\"\"\n    lookup = collections.defaultdict(set)\n\n    for doc_id, document in enumerate(tqdm(documents)):\n        normalized_text = 
normalize_string(document)\n        ngrams = word_ngrams(normalized_text, ngram_size)\n        for ngram in ngrams:\n            lookup[ngram].add(doc_id)\n\n    return lookup\n\n\ndef build_ngram_single(document: str, ngram_size: int = 8) -> set[str]:\n    normalized_text = normalize_string(document)\n    ngrams = word_ngrams(normalized_text, ngram_size)\n\n    return set(ngrams)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", type=str, required=True, help=\"Name of the dataset to check for contamination.\")\n    parser.add_argument(\"--config\", type=str, default=None, help=\"Name of the dataset config to load.\")\n    parser.add_argument(\"--split\", type=str, default=\"train\", help=\"Split to check for contamination, defaults to `train`.\")\n    parser.add_argument(\"--ngram_size\", type=int, default=8, help=\"Size of n-grams to build, defaults to 8.\")\n    parser.add_argument(\n        \"--problem_column\", type=str, default=\"problem\", help=\"Name of the column containing the problem (prompt).\"\n    )\n    parser.add_argument(\n        \"--cleanup\",\n        action=\"store_true\",\n        help=\"Whether to remove the contaminated rows before pushing the dataset.\",\n    )\n    parser.add_argument(\n        \"--new_dataset_name\",\n        type=str,\n        default=None,\n        help=\"New name for the dataset. 
If not provided, will reuse the name and add a `_decontaminated` to the name.\"\n    )\n    args = parser.parse_args()\n\n    from datasets import load_dataset, Dataset\n\n    # Load the dataset to check for contamination\n    ds = load_dataset(args.dataset, name=args.config, split=args.split)\n\n    eval_datasets = {\n        \"aime_2024\": (load_dataset(\"HuggingFaceH4/aime_2024\", split=\"train\"), \"problem\"),\n        \"aime_2025\": (load_dataset(\"yentinglin/aime_2025\", split=\"train\"), \"problem\"),\n        \"math_500\": (load_dataset(\"HuggingFaceH4/MATH-500\", split=\"test\"), \"problem\"),\n        \"gpqa\": (load_dataset(\"Idavidrein/gpqa\", \"gpqa_diamond\", split=\"train\", trust_remote_code=True), \"Question\"),\n        \"lcb\": (\n            load_dataset(\n                \"livecodebench/code_generation_lite\", split=\"test\", version_tag=\"v4_v5\", trust_remote_code=True\n            ),\n            \"question_content\",\n        ),\n    }\n    ngram_lookups = {}\n    for ds_name, (eval_dataset, problem_col) in eval_datasets.items():\n        ngram_lookups[ds_name] = build_ngram_lookup(eval_dataset[problem_col], ngram_size=args.ngram_size)\n\n    for eval_name, ngram_lookup in ngram_lookups.items():\n        # Update the ngram_lookup variable for each dataset\n        def find_contaminated(row):\n            # For each example we have to build the ngrams and check for all of them on each row\n            ngrams = build_ngram_single(row[args.problem_column], ngram_size=args.ngram_size)\n            row[f\"contaminated_{eval_name}\"] = any(set(ngram in ngram_lookup for ngram in ngrams))\n            return row\n\n        ds = ds.map(find_contaminated, num_proc=8)\n\n    # Allow cleaning up via CLI args (removing the contaminated examples and dropping the columns)\n    def cleanup(dataset: Dataset) -> Dataset:\n        initial_size = len(dataset)\n        contamination_cols = [col for col in dataset.column_names if 
col.startswith(\"contaminated_\")]\n        for col in contamination_cols:\n            if col.startswith(\"contaminated_\"):\n                size_prior = len(dataset)\n                dataset = dataset.filter(lambda x: not x[col], num_proc=8)\n                if len(dataset) < size_prior:\n                    print(f\"Removed {size_prior - len(dataset)} samples from '{col.replace('contaminated_', '')}'\")\n        dataset = dataset.remove_columns(contamination_cols)\n        print(f\"Initial size: {initial_size}, Final size: {len(dataset)}\")\n        return dataset\n\n    if args.cleanup:\n        ds = cleanup(ds)\n\n    new_ds_name = args.new_dataset_name or f\"{args.dataset}_decontaminated\"\n    config_name = args.config if args.config is not None else \"default\"\n    url = ds.push_to_hub(new_ds_name, config_name=config_name, split=\"train\")\n    print(f\"Decontaminated dataset: {url}\")\n"
  },
  {
    "path": "scripts/e2b_router.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport asyncio\nfrom fastapi import FastAPI\nfrom pydantic import BaseModel, ConfigDict\nfrom typing import  Optional\nfrom fastapi import FastAPI, Request\nimport argparse\nimport asyncio\nfrom fastapi import FastAPI\nimport uvicorn\nfrom e2b_code_interpreter.models import Execution\nfrom dotenv import load_dotenv\nfrom e2b_code_interpreter import AsyncSandbox\n\nload_dotenv()\n\nclass BatchRequest(BaseModel):\n    \"\"\"\n    BatchRequest is a data model representing a batch processing request.\n\n    Attributes:\n        scripts (list[str]): A list of script names or paths to be executed.\n        languages (list[str]): The programming languages for each script in the list.\n        timeout (int): The maximum allowed execution time for each script in seconds.\n        request_timeout (int): The maximum allowed time for the entire batch request in seconds.\n    \"\"\"\n    scripts: list[str]\n    languages: list[str]\n    timeout: int\n    request_timeout: int\n\nclass ScriptResult(BaseModel):\n    \"\"\"\n    ScriptResult is a Pydantic model that represents the result of a script execution.\n    Attributes:\n        execution (Optional[Execution]): An optional instance of the `Execution` class \n            that contains details about the script's execution, such as status, output, \n            or 
any other relevant metadata.\n        exception_str (Optional[str]): An optional string that captures the exception \n            message or details if an error occurred during the script's execution.\n        model_config (ConfigDict): A configuration dictionary that allows arbitrary \n            types to be used within the Pydantic model. This is necessary to support \n            custom types like `Execution` within the model.\n    \"\"\"\n    execution: Optional[Execution]\n    exception_str: Optional[str]\n    \n    # required to allow arbitrary types in pydantic models such as Execution\n    model_config = ConfigDict(arbitrary_types_allowed=True)\n    \ndef create_app(args):\n    \"\"\"\n    Creates and configures a FastAPI application instance.\n    Args:\n        args: An object containing configuration parameters for the application.\n              - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.\n    Returns:\n        FastAPI: A configured FastAPI application instance.\n    The application includes the following endpoints:\n        1. GET /health:\n            - Returns the health status of the application.\n            - Response: {\"status\": \"ok\"}\n        2. 
POST /execute_batch:\n            - Executes a batch of scripts in an isolated sandbox environment.\n            - Request Body: BatchRequest object containing:\n                - languages (list[str]): The programming languages of the scripts (python or javascript).\n                - timeout (int): The maximum execution time for each script.\n                - request_timeout (int): The timeout for the request itself.\n                - scripts (List[str]): A list of scripts to execute.\n            - Response: A list of ScriptResult objects for each script, containing:\n                - execution: The result of the script execution.\n                - exception_str: Any exception encountered during execution.\n    Notes:\n        - A semaphore is used to limit the number of concurrent sandboxes.\n        - Each script execution is wrapped in a timeout to prevent hanging.\n        - Sandboxes are cleaned up after execution, even in case of errors.\n    \"\"\"\n    app = FastAPI()\n\n    # Instantiate semaphore and attach it to app state\n    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)\n\n    @app.get(\"/health\")\n    async def health():\n        return {\"status\": \"ok\"}\n\n    @app.post(\"/execute_batch\")\n    async def execute_batch(batch: BatchRequest, request: Request):\n        semaphore = request.app.state.sandbox_semaphore\n        languages = batch.languages\n        timeout = batch.timeout\n        request_timeout = batch.request_timeout\n        asyncio_timeout = batch.timeout + 1\n        \n        async def run_script(script: str, language: str) -> ScriptResult:\n\n            async with semaphore:\n                try:\n                    sandbox = await AsyncSandbox.create(\n                        timeout=timeout,\n                        request_timeout=request_timeout,\n                    )\n                    execution = await asyncio.wait_for(\n                        sandbox.run_code(script, 
language=language),\n                        timeout=asyncio_timeout,\n                    )\n                    return ScriptResult(execution=execution, exception_str=None)\n\n                except Exception as e:\n                    return ScriptResult(execution=None, exception_str=str(e))\n                \n                finally:\n                    try:\n                        await sandbox.kill()\n                    except Exception:\n                        pass\n\n        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]\n        return await asyncio.gather(*tasks)\n\n    return app\n\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments for the e2b_router script.\n\n    Arguments:\n        --host (str): The hostname or IP address to bind the server to. Defaults to \"0.0.0.0\" (binds to all interfaces).\n        --port (int): The port number on which the server will listen. Defaults to 8000.\n        --max_num_sandboxes (int): The maximum number of sandboxes that can be created or managed simultaneously. Defaults to 20.\n\n    Returns:\n        argparse.Namespace: Parsed command-line arguments as an object.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", default=\"0.0.0.0\")\n    parser.add_argument(\"--port\", type=int, default=8000)\n    parser.add_argument(\"--max_num_sandboxes\", type=int, default=20)\n    return parser.parse_args()\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    app = create_app(args)\n\n    uvicorn.run(app, host=args.host, port=args.port)"
  },
  {
    "path": "scripts/generate_reasoning.py",
    "content": "import argparse\nimport asyncio\nimport hashlib\nimport json\nimport os\nimport random\nfrom asyncio import Lock\nfrom typing import Set\n\nfrom datasets import load_dataset\nfrom tqdm.asyncio import tqdm\n\nimport aiofiles\nimport aiohttp\nimport uvloop\n\n\nfile_lock = Lock()\n\n\nasync def generate_completion(session, prompt, args):\n    retry_budget = 10\n    while retry_budget > 0:\n        try:\n            await asyncio.sleep(random.uniform(0.0, 0.1))\n            async with session.post(\n                f\"http://{args.api_addr}/v1/chat/completions\",\n                json={\n                    \"model\": \"default\",\n                    \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n                    \"max_tokens\": args.max_tokens,\n                    \"temperature\": args.temperature,\n                    \"top_p\": args.top_p,\n                },\n                headers={\"Authorization\": \"Bearer EMPTY\"},\n            ) as response:\n                return await response.json(content_type=None)\n        except Exception as e:\n            print(f\"API error (will retry): {e}\")\n            retry_budget -= 1\n            await asyncio.sleep(10)\n    return None\n\n\nasync def process_example(example, session, args, output_file, pbar):\n    prompt = args.prompt_template.format(prompt=example[args.prompt_column])\n\n    try:\n        tasks = [generate_completion(session, prompt, args) for _ in range(args.num_generations)]\n\n        completions = await asyncio.gather(*tasks)\n\n        if any(completion is None for completion in completions):\n            print(f\"Error processing example\")\n            pbar.update(1)\n            return None\n\n        generations = []\n        finish_reasons = []\n        api_metadata = []\n\n        for completion in completions:\n            generations.append(completion[\"choices\"][0][\"message\"][\"content\"])\n            
finish_reasons.append(completion[\"choices\"][0][\"finish_reason\"])\n            api_metadata.append(completion[\"usage\"])\n\n        # Combine original dataset fields with generations\n        result = {\n            **example,  # Preserve all original dataset fields\n            \"generations\": generations,\n            \"finish_reasons\": finish_reasons,\n            \"api_metadata\": api_metadata,\n        }\n\n        # Write to file with lock\n        async with file_lock:\n            async with aiofiles.open(output_file, mode=\"a\") as f:\n                await f.write(json.dumps(result) + \"\\n\")\n                await f.flush()\n\n        pbar.set_postfix(active=len(pbar.active_tasks), refresh=False)\n        pbar.update(1)\n\n        return result\n    except Exception as e:\n        print(f\"Error processing example: {e}\")\n        pbar.update(1)\n        return None\n\n\nasync def load_processed_uuids(output_file, uuid_column):\n    processed_uuids = set()\n    if os.path.exists(output_file):\n        async with aiofiles.open(output_file, mode=\"r\") as f:\n            async for line in f:\n                try:\n                    data = json.loads(line)\n                    processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())\n                except json.JSONDecodeError:\n                    continue\n    return processed_uuids\n\n\nasync def main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset-name\", type=str, required=True)\n    parser.add_argument(\"--output-file\", type=str, required=True)\n    parser.add_argument(\"--prompt-column\", type=str, required=True)\n    parser.add_argument(\"--uuid-column\", type=str, required=True)\n    parser.add_argument(\"--api-addr\", type=str, default=\"localhost:39876\")\n    parser.add_argument(\"--num-generations\", type=int, default=4)\n    parser.add_argument(\n        \"--prompt-template\",\n        type=str,\n        default=\"You will be given 
a problem. Please reason step by step, and put your final answer within \\\\boxed{{}}:\\n{prompt}\",\n    )\n    parser.add_argument(\"--temperature\", type=float, default=0.6)\n    parser.add_argument(\"--top-p\", type=float, default=0.95)\n    parser.add_argument(\"--max-tokens\", type=int, default=16384)\n    parser.add_argument(\"--max-concurrent\", type=int, default=1000)\n    args = parser.parse_args()\n\n    dataset = load_dataset(args.dataset_name, split=\"train\").shuffle()\n    processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)\n    if processed_uuids:\n        print(f\"Found {len(processed_uuids)} already processed examples, resuming from there...\")\n\n    if not os.path.exists(args.output_file):\n        async with aiofiles.open(args.output_file, mode=\"w\") as f:\n            await f.write(\"\")\n\n    active_tasks: Set[asyncio.Task] = set()\n\n    pbar = tqdm(\n        total=len(dataset) - len(processed_uuids),\n        desc=\"Generating responses\",\n        unit=\"row\",\n        mininterval=2,\n        smoothing=0.0001,\n    )\n    pbar.active_tasks = active_tasks\n\n    async with aiohttp.ClientSession(\n        timeout=aiohttp.ClientTimeout(total=60 * 60),\n        connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),\n    ) as session:\n        for example in dataset:\n            uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()\n            if uuid not in processed_uuids:\n                # Wait if we've hit the concurrency limit\n                while len(active_tasks) >= args.max_concurrent:\n                    done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)\n                    for task in done:\n                        try:\n                            await task\n                        except Exception as e:\n                            print(f\"Task failed: {e}\")\n\n                task = 
asyncio.create_task(process_example(example, session, args, args.output_file, pbar))\n                active_tasks.add(task)\n                task.add_done_callback(active_tasks.discard)\n\n                pbar.set_postfix(active=len(active_tasks), refresh=True)\n\n        # Wait for remaining tasks\n        if active_tasks:\n            await asyncio.gather(*active_tasks, return_exceptions=True)\n\n    pbar.close()\n\n\nif __name__ == \"__main__\":\n    uvloop.install()\n    asyncio.run(main())\n"
  },
  {
    "path": "scripts/get_tensor_parallel_size.py",
    "content": "import argparse\nfrom transformers import AutoConfig\nfrom math import gcd\n\ndef get_tensor_parallel_size(model_name: str, revision: str = None, default_tp: int = 8) -> int:\n    try:\n        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)\n        num_heads = getattr(config, 'num_attention_heads', None)\n\n        if num_heads is not None and num_heads % default_tp != 0:\n            tp = gcd(num_heads, default_tp)\n            return max(tp, 1)\n        else:\n            return default_tp\n    except Exception as e:\n        print(f\"Warning: Failed to fetch config for {model_name}@{revision}: {e}\")\n        return default_tp\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model_name\", type=str, required=True, help=\"Hugging Face model name or path\")\n    parser.add_argument(\"--revision\", type=str, default=None, help=\"Model revision if applicable\")\n    parser.add_argument(\"--default_tp\", type=int, default=8, help=\"Default TP size (usually GPUs per node)\")\n\n    args = parser.parse_args()\n\n    tp = get_tensor_parallel_size(args.model_name, args.revision, args.default_tp)\n    print(tp)\n"
  },
  {
    "path": "scripts/morph_router.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport asyncio\nfrom fastapi import FastAPI\nfrom pydantic import BaseModel, ConfigDict\nfrom typing import Optional, List\nfrom fastapi import FastAPI, Request\nimport uvicorn\nfrom dotenv import load_dotenv\nimport os\n\nload_dotenv()\n\nclass BatchRequest(BaseModel):\n    \"\"\"\n    BatchRequest is a data model representing a batch processing request.\n\n    Attributes:\n        scripts (list[str]): A list of script names or paths to be executed.\n        languages (List[str]): The programming languages for each script in the list.\n        timeout (int): The maximum allowed execution time for each script in seconds.\n        request_timeout (int): The maximum allowed time for the entire batch request in seconds.\n    \"\"\"\n    scripts: List[str]\n    languages: List[str]\n    timeout: int\n    request_timeout: int\n\nclass ScriptResult(BaseModel):\n    \"\"\"\n    ScriptResult is a Pydantic model that represents the result of a script execution.\n    Attributes:\n        text (Optional[str]): The output text from the script execution.\n        exception_str (Optional[str]): An optional string that captures the exception \n            message or details if an error occurred during the script's execution.\n        model_config (ConfigDict): A configuration dictionary that allows arbitrary \n         
   types to be used within the Pydantic model.\n    \"\"\"\n    text: Optional[str]\n    exception_str: Optional[str]\n    \n    \n    model_config = ConfigDict(arbitrary_types_allowed=True)\n    \ndef create_app(args):\n    \"\"\"\n    Creates and configures a FastAPI application instance for the MorphCloud router.\n    \n    Args:\n        args: An object containing configuration parameters for the application.\n              - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.\n              - api_key (str): The MorphCloud API key to use.\n    \n    Returns:\n        FastAPI: A configured FastAPI application instance.\n    \"\"\"\n    app = FastAPI()\n    \n    from morphcloud.api import MorphCloudClient\n    from morphcloud.sandbox import Sandbox\n    \n    app.state.client = MorphCloudClient(api_key=args.api_key)\n    app.state.Sandbox = Sandbox\n\n    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)\n\n    @app.get(\"/health\")\n    async def health():\n        return {\"status\": \"ok\"}\n\n    @app.post(\"/execute_batch\")\n    async def execute_batch(batch: BatchRequest, request: Request):\n        semaphore = request.app.state.sandbox_semaphore\n        client = request.app.state.client\n        Sandbox = request.app.state.Sandbox\n        \n        languages = batch.languages\n        timeout = batch.timeout\n        request_timeout = batch.request_timeout\n        asyncio_timeout = batch.timeout + 1\n        \n        async def run_script(script: str, language: str) -> ScriptResult:\n            sandbox = None\n            sandbox_id = \"unknown\"\n\n            async with semaphore:\n                try:\n                    sandbox = await asyncio.to_thread(\n                        Sandbox.new,\n                        client=client,\n                        ttl_seconds=timeout\n                    )\n                    \n                    sandbox_id = getattr(sandbox, 'id', None) or 
getattr(sandbox._instance, 'id', 'unknown')\n                    \n                    execution = await asyncio.wait_for(\n                        asyncio.to_thread(\n                            sandbox.run_code,\n                            script,\n                            language=language,\n                            timeout=timeout * 1000  \n                        ),\n                        timeout=asyncio_timeout,\n                    )\n                    \n                    if hasattr(execution, 'text') and execution.text:\n                        return ScriptResult(text=execution.text, exception_str=None)\n                    elif hasattr(execution, 'stdout') and execution.stdout:\n                        return ScriptResult(text=execution.stdout, exception_str=None)\n                    else:\n                        return ScriptResult(text=\"\", exception_str=\"No output from execution\")\n\n                except Exception as e:\n                    return ScriptResult(text=None, exception_str=str(e))\n                \n                finally:\n                    if sandbox:\n                        try:\n                            await asyncio.to_thread(sandbox.close)\n                            await asyncio.to_thread(sandbox.shutdown)\n                        except Exception:\n                            pass\n\n        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]\n        return await asyncio.gather(*tasks)\n\n    return app\n\ndef parse_args():\n    \"\"\"\n    Parse command-line arguments for the morph_router script.\n\n    Arguments:\n        --host (str): The hostname or IP address to bind the server to. Defaults to \"0.0.0.0\".\n        --port (int): The port number on which the server will listen. Defaults to 8001.\n        --max_num_sandboxes (int): The maximum number of sandboxes that can be created simultaneously. 
Defaults to 20.\n        --api_key (str): The MorphCloud API key. If not provided, it will be read from the MORPH_API_KEY environment variable.\n\n    Returns:\n        argparse.Namespace: Parsed command-line arguments as an object.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", default=\"0.0.0.0\")\n    parser.add_argument(\"--port\", type=int, default=8001)\n    parser.add_argument(\"--max_num_sandboxes\", type=int, default=20)\n    parser.add_argument(\"--api_key\", default=os.getenv(\"MORPH_API_KEY\"))\n    args = parser.parse_args()\n    \n    if not args.api_key:\n        raise ValueError(\"MorphCloud API key not provided. Please set MORPH_API_KEY environment variable or use --api_key.\")\n    \n    return args\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    app = create_app(args)\n    \n    print(f\"Starting MorphCloud Router on {args.host}:{args.port}\")\n    uvicorn.run(app, host=args.host, port=args.port)"
  },
  {
    "path": "scripts/pass_rate_filtering/README.md",
    "content": "# Pass rate filtering\n\nWe provide support to filter datasets by generating and computing pass rate on veriable tasks\n\nSee `scripts/pass_rate_filtering/compute_pass_rate.py` and `scripts/pass_rate_filtering/launch_filtering.sh` (hardcoded for DAPO at the moment)\n\nBy default the script chunks the dataset, merge can be run using the following snippet (example for DAPO) :\n\nfrom datasets import load_dataset, concatenate_datasets\n\nname = \"open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-Merges-v00.02-v01.02-0.3-0.7-filter\"\n\n```python\ngen_datasets = []\nfilt_datasets = []\nfor start in range(0,17400,200):\n    end = start + 200\n    if start == 17200:\n        end = 17398\n    gen_config_name = f\"gen-{start}-{end}\"\n    gen_dataset = load_dataset(name, gen_config_name, revision=\"gen\",  split=\"train\")\n    gen_datasets.append(gen_dataset)\n    \n    filt_config_name = f\"filt-0.1-0.6-{start}-{end}\"\n    filt_dataset = load_dataset(name, filt_config_name, revision=\"pass_rate\",  split=\"train\")\n    filt_datasets.append(filt_dataset)\n    \ngen_dataset = concatenate_datasets(gen_datasets)\ngen_dataset.push_to_hub(name, config_name=\"gen\", split=\"train\")\nprint(gen_dataset)\n\nfilt_dataset = concatenate_datasets(filt_datasets)\nfilt_dataset.push_to_hub(name, config_name=\"default\", split=\"train\")\n\nprint(filt_dataset)\n```"
  },
  {
    "path": "scripts/pass_rate_filtering/compute_pass_rate.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# example usage python scripts/filter_dataset.py --config recipes/dataset_filtering/config_demo.yaml\n\nimport logging\nfrom dataclasses import dataclass\nfrom git import Optional\nimport torch\nimport sys\n\nimport datasets\nimport transformers\nfrom datasets import load_dataset\nfrom transformers import set_seed\n\nfrom open_r1.configs import GRPOConfig, GRPOScriptArguments\nfrom open_r1.rewards import get_reward_funcs\nfrom open_r1.utils import get_tokenizer\nfrom trl import ModelConfig, TrlParser\nfrom trl.data_utils import apply_chat_template\nfrom vllm import LLM, SamplingParams\n\nlogger = logging.getLogger(__name__)\n\n@dataclass\nclass PassRateScriptArguments(GRPOScriptArguments):\n    # we can be lazy and just use the same script args as GRPO\n    output_dataset_name: Optional[str] = None\n    pass_rate_min: float = 0.1\n    pass_rate_max: float = 0.9\n    dataset_start_index: Optional[int] = None\n    dataset_end_index: Optional[int] = None\n    dataset_split: str = \"train\"\n\n\ndef main(script_args, training_args, model_args):\n    # Set seed for reproducibility\n    set_seed(training_args.seed)\n\n    ###############\n    # Setup logging\n    ###############\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        
handlers=[logging.StreamHandler(sys.stdout)],\n    )\n    log_level = training_args.get_process_log_level()\n    logger.setLevel(log_level)\n    datasets.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.enable_default_handler()\n    transformers.utils.logging.enable_explicit_format()\n\n    logger.info(f\"Model parameters {model_args}\")\n    logger.info(f\"Script parameters {script_args}\")\n    logger.info(f\"Training parameters {training_args}\")\n\n    # Load the dataset\n    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_split)\n    if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:\n        dataset = dataset.select(range(script_args.dataset_start_index, script_args.dataset_end_index))\n\n    # Get reward functions from the registry\n    reward_funcs = get_reward_funcs(script_args)\n\n    # Format into conversation\n    def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):\n        example[\"prompt_backup\"] = example[prompt_column]\n        \n        prompt = []\n\n        if training_args.system_prompt is not None:\n            prompt.append({\"role\": \"system\", \"content\": training_args.system_prompt})\n\n        if prompt_column not in example:\n            raise ValueError(f\"Dataset Question Field Error: {prompt_column} is not supported.\")\n\n        prompt.append({\"role\": \"user\", \"content\": example[prompt_column]})\n        return {\"prompt\": prompt}\n\n    dataset = dataset.map(make_conversation)\n    tokenizer = get_tokenizer(model_args, training_args)\n    \n    if \"messages\" in dataset.column_names:\n        dataset = dataset.remove_columns(\"messages\")\n    \n    dataset = dataset.map(apply_chat_template, fn_kwargs={\"tokenizer\": tokenizer})\n    llm = LLM(\n        model=model_args.model_name_or_path,\n        
revision=model_args.model_revision,\n        trust_remote_code=model_args.trust_remote_code,\n    )\n\n    sampling_params=SamplingParams(\n        temperature=training_args.temperature,\n        top_p=training_args.top_p,\n        top_k=training_args.top_k,\n        n=training_args.num_generations,\n        max_tokens=training_args.max_completion_length,\n    )\n    \n    def batch_score(examples):\n        prompts = examples[\"prompt\"]\n        \n        outputs = llm.generate(\n            prompts,\n            sampling_params=sampling_params,\n            use_tqdm=False,\n        )\n        repeated_prompts = []\n        reward_completions = []\n        grouped_completions = []\n        for output in outputs:\n            prompt = output.prompt\n            group = []\n            for completion in output.outputs:\n                text = completion.text\n                group.append(text)\n                message = [{\"role\": \"assistant\", \"content\": text}]\n                repeated_prompts.append(prompt)\n                reward_completions.append(message)\n            grouped_completions.append(group)\n        \n        def repeat_each_element_k_times(list_to_repeat: list, k: int) -> list:\n            return [element for item in list_to_repeat for element in [item] * k]\n        \n        rewards_per_func = torch.zeros(len(repeated_prompts), len(reward_funcs))\n        for i, reward_func in enumerate(reward_funcs):\n            keys = [key for key in examples.data.keys() if key not in [\"prompt\", \"completion\"]]\n            reward_kwargs = {key: repeat_each_element_k_times(examples[key], training_args.num_generations) for key in keys}\n            output_reward_func = reward_func(prompts=repeated_prompts, completions=reward_completions, **reward_kwargs)\n            # Convert None values to NaN\n            output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]\n\n            rewards_per_func[:, i] = 
torch.tensor(output_reward_func, dtype=torch.float32)\n            \n        reshaped_rewards = rewards_per_func.view(-1, training_args.num_generations)\n        \n        examples[\"pass_rate_generations\"] = grouped_completions\n        examples[\"pass_rate_rewards\"] = reshaped_rewards.tolist()\n\n            \n        return examples\n    \n    dataset = dataset.map(batch_score, batched=True, batch_size=64)\n    \n    # we need to restore the prompt for the final dataset\n    def restore_prompt(example):\n        example[\"prompt\"] = example[\"prompt_backup\"]\n        return example\n    \n    dataset = dataset.map(restore_prompt)\n    dataset = dataset.remove_columns(\"prompt_backup\")\n    \n    if script_args.output_dataset_name is not None:\n        output_dataset_name = script_args.output_dataset_name\n    else:\n        model_name = model_args.model_name_or_path\n        if \"/\" in model_name:\n            model_name = model_name.split(\"/\")[-1]\n        model_revision = model_args.model_revision\n    \n        output_dataset_name = f\"{script_args.dataset_name}-{model_name}-{model_revision}-gen\"\n    \n    config_name=\"default\"\n    filtered_config_name = f\"filt-{script_args.pass_rate_min}-{script_args.pass_rate_max}\"\n    \n    if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:\n        config_name = f\"gen-{script_args.dataset_start_index}-{script_args.dataset_end_index}\"\n        filtered_config_name = f\"{filtered_config_name}-{script_args.dataset_start_index}-{script_args.dataset_end_index}\"\n        \n    dataset.push_to_hub(output_dataset_name, config_name=config_name, revision=\"gen\")\n    \n    def filter_func(example):\n        rewards = example[\"pass_rate_rewards\"]\n        # get the mean of the rewards that are not None\n        mean_reward = torch.nanmean(torch.tensor(rewards, dtype=torch.float32))\n        \n        return script_args.pass_rate_min < mean_reward < 
script_args.pass_rate_max\n    \n    logger.info(f\"Filtering dataset with low reward threshold {script_args.pass_rate_min} and high reward threshold {script_args.pass_rate_max}\")\n    logger.info(f\"Dataset size before filtering: {dataset}\")\n    dataset = dataset.filter(filter_func)\n    logger.info(f\"Dataset size after filtering: {dataset}\")\n    dataset.push_to_hub(output_dataset_name, config_name=filtered_config_name, revision=\"pass_rate\")\n    \n    \n\nif __name__ == \"__main__\":\n    parser = TrlParser((PassRateScriptArguments, GRPOConfig, ModelConfig))\n    script_args, training_args, model_args = parser.parse_args_and_config()\n    main(script_args, training_args, model_args)\n"
  },
  {
    "path": "scripts/pass_rate_filtering/launch_filtering.sh",
    "content": "\n\n# a bash foor loop from 0 to 17,400 in chunks of 200\n\nfor i in {0..17000..200}\ndo\n  START=$i\n  END=$((i + 200))\n  echo \"Processing chunk from $START to $END\"\n  \n  # Submit the job to SLURM\n  sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml $START $END\ndone\n\nsbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml 17200 17398\n"
  },
  {
    "path": "scripts/run_benchmarks.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass, field\nfrom typing import List, Optional\n\nfrom open_r1.utils.evaluation import SUPPORTED_BENCHMARKS, run_benchmark_jobs\nfrom open_r1.configs import SFTConfig\nfrom trl import ModelConfig, TrlParser\n\n\n@dataclass\nclass ScriptArguments:\n    model_id: str = field(\n        default=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n        metadata={\"help\": \"The Hub model id to push the model to.\"},\n    )\n    model_revision: str = field(default=\"main\", metadata={\"help\": \"The Hub model branch to push the model to.\"})\n    trust_remote_code: bool = field(default=False, metadata={\"help\": \"Trust the remote code.\"})\n    benchmarks: List[str] = field(\n        default_factory=lambda: [], metadata={\"help\": \"The benchmarks to run after training.\"}\n    )\n    list_benchmarks: bool = field(default=False, metadata={\"help\": \"List all supported benchmarks.\"})\n    system_prompt: Optional[str] = field(\n        default=None, metadata={\"help\": \"The system prompt to use for the benchmark.\"}\n    )\n\n\ndef main():\n    parser = TrlParser(ScriptArguments)\n    args = parser.parse_args_and_config()[0]\n    if args.list_benchmarks:\n        print(\"Supported benchmarks:\")\n        for benchmark in SUPPORTED_BENCHMARKS:\n            print(f\"  - {benchmark}\")\n        return\n    
benchmark_args = SFTConfig(\n        output_dir=\"\",\n        hub_model_id=args.model_id,\n        hub_model_revision=args.model_revision,\n        benchmarks=args.benchmarks,\n        system_prompt=args.system_prompt,\n    )\n    run_benchmark_jobs(\n        benchmark_args,\n        ModelConfig(model_name_or_path=\"\", model_revision=\"\", trust_remote_code=args.trust_remote_code),\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/upload_details.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPush the details from a LightEval run to the Hub.\n\nUsage:\n\npython src/open_r1/utils/upload_details.py \\\n    --data_files {path_to_parquet_file} \\\n    --hub_repo_id {hub_repo_id} \\\n    --config_name {config_name}\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nfrom datasets import load_dataset\nfrom transformers import HfArgumentParser\n\n\n@dataclass\nclass ScriptArguments:\n    data_files: List[str] = field(default_factory=list)\n    hub_repo_id: str = None\n    config_name: str = None\n\n\ndef main():\n    parser = HfArgumentParser(ScriptArguments)\n    args = parser.parse_args_into_dataclasses()[0]\n\n    if all(file.endswith(\".json\") for file in args.data_files):\n        ds = load_dataset(\"json\", data_files=args.data_files)\n    elif all(file.endswith(\".jsonl\") for file in args.data_files):\n        ds = load_dataset(\"json\", data_files=args.data_files)\n    else:\n        ds = load_dataset(\"parquet\", data_files=args.data_files)\n    url = ds.push_to_hub(args.hub_repo_id, config_name=args.config_name, private=True)\n    print(f\"Dataset available at: {url}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "setup.cfg",
    "content": "[isort]\ndefault_section = FIRSTPARTY\nensure_newline_before_comments = True\nforce_grid_wrap = 0\ninclude_trailing_comma = True\nknown_first_party = open_r1\nknown_third_party =\n    transformers\n    datasets\n    fugashi\n    git\n    h5py\n    matplotlib\n    nltk\n    numpy\n    packaging\n    pandas\n    psutil\n    pytest\n    rouge_score\n    sacrebleu\n    seqeval\n    sklearn\n    streamlit\n    torch\n    tqdm\n\nline_length = 119\nlines_after_imports = 2\nmulti_line_output = 3\nuse_parentheses = True\n\n[flake8]\nignore = E203, E501, E741, W503, W605\nmax-line-length = 119\nper-file-ignores =\n    # imported but unused\n    __init__.py: F401\n\n[tool:pytest]\ndoctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
  },
  {
    "path": "setup.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py\n\n\nimport re\nimport shutil\nfrom pathlib import Path\n\nfrom setuptools import find_packages, setup\n\n\n# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466\nstale_egg_info = Path(__file__).parent / \"open_r1.egg-info\"\nif stale_egg_info.exists():\n    print(\n        (\n            \"Warning: {} exists.\\n\\n\"\n            \"If you recently updated open_r1, this is expected,\\n\"\n            \"but it may prevent open_r1 from installing in editable mode.\\n\\n\"\n            \"This directory is automatically generated by Python's packaging tools.\\n\"\n            \"I will remove it now.\\n\\n\"\n            \"See https://github.com/pypa/pip/issues/5466 for details.\\n\"\n        ).format(stale_egg_info)\n    )\n    shutil.rmtree(stale_egg_info)\n\n\n# IMPORTANT: all dependencies should be listed here with their version requirements, if any.\n#   * If a dependency is fast-moving (e.g. 
trl), pin to the exact version\n_deps = [\n    \"accelerate==1.4.0\",\n    \"bitsandbytes>=0.43.0\",\n    \"datasets>=3.2.0\",\n    \"deepspeed==0.16.8\",\n    \"distilabel[vllm,ray,openai]>=1.5.2\",\n    \"e2b-code-interpreter>=1.0.5\",\n    \"einops>=0.8.0\",\n    \"flake8>=6.0.0\",\n    \"hf_transfer>=0.1.4\",\n    \"huggingface-hub[cli,hf_xet]>=0.30.2,<1.0\",\n    \"isort>=5.12.0\",\n    \"jieba\",  # Needed for Chinese language support\n    \"langdetect\",  # Needed for LightEval's extended tasks\n    \"latex2sympy2_extended>=1.0.6\",\n    \"liger-kernel>=0.5.10\",\n    \"lighteval @ git+https://github.com/huggingface/lighteval.git@d3da6b9bbf38104c8b5e1acc86f83541f9a502d1\",  # Critical bug fix for tokenizer revisions: https://github.com/huggingface/lighteval/pull/721\n    \"math-verify==0.5.2\",  # Used for math verification in grpo\n    \"morphcloud==0.1.67\",\n    \"packaging>=23.0\",\n    \"parameterized>=0.9.0\",\n    \"peft>=0.14.0\",\n    \"pytest\",\n    \"python-dotenv\",\n    \"ruff>=0.9.0\",\n    \"safetensors>=0.3.3\",\n    \"sentencepiece>=0.1.99\",\n    \"torch==2.6.0\",\n    \"transformers==4.52.3\",\n    \"trl[vllm]==0.18.0\",\n    \"wandb>=0.19.1\",\n    \"async-lru>=2.0.5\",\n    \"aiofiles>=24.1.0\",\n    \"pandas>=2.2.3\",\n]\n\n# this is a lookup table with items like:\n#\n# tokenizers: \"tokenizers==0.9.4\"\n# packaging: \"packaging\"\n#\n# some of the values are versioned whereas others aren't.\ndeps = {b: a for a, b in (re.findall(r\"^(([^!=<>~ \\[\\]]+)(?:\\[[^\\]]+\\])?(?:[!=<>~ ].*)?$)\", x)[0] for x in _deps)}\n\n\ndef deps_list(*pkgs):\n    return [deps[pkg] for pkg in pkgs]\n\n\nextras = {}\nextras[\"tests\"] = deps_list(\"pytest\", \"parameterized\", \"math-verify\", \"jieba\")\nextras[\"torch\"] = deps_list(\"torch\")\nextras[\"quality\"] = deps_list(\"ruff\", \"isort\", \"flake8\")\nextras[\"code\"] = deps_list(\"e2b-code-interpreter\", \"python-dotenv\", \"morphcloud\", \"jieba\", \"pandas\", \"aiofiles\")\nextras[\"eval\"] = 
deps_list(\"lighteval\", \"math-verify\")\nextras[\"dev\"] = extras[\"quality\"] + extras[\"tests\"] + extras[\"eval\"] + extras[\"code\"]\n\n# core dependencies shared across the whole project - keep this to a bare minimum :)\ninstall_requires = [\n    deps[\"accelerate\"],\n    deps[\"bitsandbytes\"],\n    deps[\"einops\"],\n    deps[\"datasets\"],\n    deps[\"deepspeed\"],\n    deps[\"hf_transfer\"],\n    deps[\"huggingface-hub\"],\n    deps[\"langdetect\"],\n    deps[\"latex2sympy2_extended\"],\n    deps[\"math-verify\"],\n    deps[\"liger-kernel\"],\n    deps[\"packaging\"],  # utilities from PyPA to e.g., compare versions\n    deps[\"safetensors\"],\n    deps[\"sentencepiece\"],\n    deps[\"transformers\"],\n    deps[\"trl\"],\n    deps[\"wandb\"],\n    deps[\"async-lru\"],\n]\n\nsetup(\n    name=\"open-r1\",\n    version=\"0.1.0.dev0\",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)\n    author=\"The Hugging Face team (past and future)\",\n    author_email=\"lewis@huggingface.co\",\n    description=\"Open R1\",\n    long_description=open(\"README.md\", \"r\", encoding=\"utf-8\").read(),\n    long_description_content_type=\"text/markdown\",\n    keywords=\"llm inference-time compute reasoning\",\n    license=\"Apache\",\n    url=\"https://github.com/huggingface/open-r1\",\n    package_dir={\"\": \"src\"},\n    packages=find_packages(\"src\"),\n    zip_safe=False,\n    extras_require=extras,\n    python_requires=\">=3.10.9\",\n    install_requires=install_requires,\n    classifiers=[\n        \"Development Status :: 3 - Alpha\",\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Education\",\n        \"Intended Audience :: Science/Research\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Operating System :: OS Independent\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.10\",\n        \"Topic :: 
Scientific/Engineering :: Artificial Intelligence\",\n    ],\n)\n"
  },
  {
    "path": "slurm/README.md",
    "content": "## Serving DeepSeek-R1 on 2x8 H100 SLURM nodes with SGLang \n\n1. Set up the environment (adjust for your cuda version):\n```bash\nconda create -n sglang124 python=3.11\nconda activate sglang124\n\npip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124\n\npip install sgl-kernel --force-reinstall --no-deps\npip install \"sglang[all]>=0.4.2.post4\" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/\n```\n\n2. Run the server and wait for the model to load:\n```bash\nsbatch slurm/serve_r1.slurm -m \"/fsx/deepseek-r1-checkpoint\" -e \"sglang124\"\n```\n\n3. Run the data generation script:\n```bash\npython scripts/generate_reasoning.py \\\n    --dataset-name \"AI-MO/NuminaMath-1.5\" \\\n    --output-file \"numinamath_r1_generations.jsonl\" \\\n    --prompt-column \"problem\" \\\n    --uuid-column \"problem\" \\\n    --api-addr \"<SGLANG_SERVER_ADDRESS>:39877\" \\\n    --num-generations 2 \\\n    --max-tokens 16384 \\\n    --max-concurrent 200\n```"
  },
  {
    "path": "slurm/compute_pass_rate.slurm",
    "content": "#!/bin/bash\n\n#SBATCH --job-name=open-r1-compute-pass-rate\n#SBATCH --partition=hopper-prod\n#SBATCH --qos=normal\n#SBATCH --nodes=1\n#SBATCH --gpus-per-node=1\n#SBATCH --output=./logs/%x-%j.out\n#SBATCH --error=./logs/%x-%j.err\n#SBATCH --time=01-00:00:00\n#SBATCH --requeue\n\n# example usage: sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml 0 500\n\nset -x -e\n\nsource ~/.bashrc\nsource openr1/bin/activate\n\npython scripts/pass_rate_filtering/compute_pass_rate.py --config $1 --dataset_start_index $2 --dataset_end_index $3"
  },
  {
    "path": "slurm/e2b_router.slurm",
    "content": "#!/bin/bash\n\n#SBATCH --partition=hopper-cpu\n#SBATCH --mem=16g\n#SBATCH --cpus-per-task=16\n#SBATCH --output=/fsx/open-r1/logs/e2b_router/%x-%j.out\n#SBATCH --error=/fsx/open-r1/logs/e2b_router/%x-%j.err\n#SBATCH --requeue\n#SBATCH --time=7-00:00:00\n\necho \"Starting job\"\nset -x -e\n\nsource ~/.bashrc\nsource openr1/bin/activate\n\nsrun python scripts/e2b_router.py"
  },
  {
    "path": "slurm/evaluate.slurm",
    "content": "#!/bin/bash\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:8\n#SBATCH --partition=hopper-prod\n#SBATCH --output=./logs/%x-%j.out\n#SBATCH --error=./logs/%x-%j.err\n#SBATCH --requeue\n#SBATCH --time=1-00:00:00\n\n\n# Specific configuration optimized for the Hugging Face Compute Cluster\n# Be ye warned this may not work on other clusters!\nmodule load cuda/12.4\n\n# Refresh Weka on h4 cache\necho \"Refreshing Weka filesystem...\"\nfind -L /fsx/h4/ -type f | xargs -d '\\n' -r -n512 -P64 weka fs tier fetch\n\n# Needed for vLLM\nexport VLLM_WORKER_MULTIPROC_METHOD=spawn\n\nset -x -e\n\nsource ~/.bashrc\nsource openr1/bin/activate\n\nTASK_NAME=$1\nTASKS=$2\nMODEL_ID=$3\nMODEL_REVISION=$4\n# Optional args\n[ -z \"$5\" ] && TENSOR_PARALLEL=False || TENSOR_PARALLEL=$5\n[ -z \"$6\" ] && TRUST_REMOTE_CODE=False || TRUST_REMOTE_CODE=$6\n# $7 is reserved for system_prompt, see the lighteval command below\nNUM_GPUS=$(nvidia-smi -L | wc -l)\n\n# Use TP to shard model across GPUs\nif [ \"$TENSOR_PARALLEL\" = \"True\" ]; then\n    MODEL_ARGS=\"model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nelse\n    MODEL_ARGS=\"model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}\"\nfi\n\nLM_EVAL_REPO_ID=\"open-r1/open-r1-eval-leaderboard\"\nMODEL_NAME=$(echo $MODEL_ID | sed 's/\\//_/g') # replaces / with _\nDETAILS_REPO_ID=\"open-r1/details-$MODEL_NAME\"\nOUTPUT_DIR=\"eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME\"\n# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get propagated which causes errors during evaluation\nACCELERATE_USE_DEEPSPEED=false\n\necho \"Running lighteval script ...\"\necho \"Eval results will be saved to $OUTPUT_DIR\"\nlighteval vllm \"$MODEL_ARGS\" $TASKS \\\n    --use-chat-template \\\n    --output-dir $OUTPUT_DIR \\\n    --save-details \\\n    ${7:+--system-prompt \"$(echo \"$7\" | base64 --decode)\"}\n\nOUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \\( -name \"*.json\" \\))\nfor filepath in $OUTPUT_FILEPATHS; do\n    echo \"Uploading $filepath to Hugging Face Hub...\"\n    filename=$(basename -- \"$filepath\")\n    for attempt in {1..20}; do\n        if huggingface-cli upload --repo-type space --private $LM_EVAL_REPO_ID $filepath $OUTPUT_DIR/$filename; then\n            echo \"Upload succeeded for $filepath\"\n            break\n        else\n            echo \"Upload failed for $filepath. Attempt $attempt of 20. Retrying in 5 seconds...\"\n            sleep 5\n        fi\n    done\ndone\n\necho \"Uploading details to Hugging Face Hub...\"\nDETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \\( -name \"*.parquet\" \\))\necho \"DETAILS_FILEPATHS: $DETAILS_FILEPATHS\"\nTIMESTAMP=$(date +\"%Y-%m-%dT%H-%M-%S\")\npython scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP\n    \necho \"Cleaning up ...\"\nrm -rf $OUTPUT_DIR\n\necho \"Done!\"\n"
  },
  {
    "path": "slurm/experimental/serve_r1_vllm.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=r1-vllm\n#SBATCH --partition=hopper-prod\n#SBATCH --qos=normal\n#SBATCH --nodes=4\n#SBATCH --gpus-per-node=8\n#SBATCH --exclusive\n#SBATCH --output=./logs/%x_%j_%n.out\n#SBATCH --error=./logs/%x_%j_%n.err\n#SBATCH --time=7-00:00:00\n#SBATCH --ntasks-per-node=1\n\nset -exuo pipefail\n\nMODEL_PATH=\"deepseek-ai/DeepSeek-R1\"\nCONDA_ENV=\"vllm7\"\nSERVER_PORT=8000\nRAY_PORT=6379\nRAY_DASHBOARD_PORT=8265\n\nwhile getopts \"m:e:h\" opt; do\n    case $opt in\n        m) MODEL_PATH=\"$OPTARG\" ;;\n        e) CONDA_ENV=\"$OPTARG\" ;;\n        h|?) echo \"Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV]\"; exit 1 ;;\n    esac\ndone\n\n# Environment setup\nmodule load cuda/12.1\nsource ~/.bashrc\nsource \"$CONDA_PREFIX/etc/profile.d/conda.sh\"\nconda activate \"$CONDA_ENV\" || { echo \"Failed to activate conda env $CONDA_ENV\"; exit 1; }\n\n# Get nodes information\nNODES=($(scontrol show hostnames \"$SLURM_JOB_NODELIST\"))\nHEAD_NODE=\"${NODES[0]}\"\nHEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w \"$HEAD_NODE\" hostname --ip-address)\n\necho \"SLURM_JOB_ID: $SLURM_JOB_ID\"\necho \"SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST\"\necho \"Head node: $HEAD_NODE ($HEAD_NODE_IP)\"\n\n# Start Ray head node\necho \"Starting Ray head node at $HEAD_NODE\"\nsrun --nodes=1 --ntasks=1 -w \"$HEAD_NODE\" \\\n    ray start --head \\\n    --node-ip-address=\"$HEAD_NODE_IP\" \\\n    --port=$RAY_PORT \\\n    --dashboard-host=0.0.0.0 \\\n    --dashboard-port=$RAY_DASHBOARD_PORT \\\n    --block &\n\nsleep 10\n\n# Start Ray worker nodes\nWORKER_COUNT=$((SLURM_JOB_NUM_NODES - 1))\nfor ((i = 1; i <= WORKER_COUNT; i++)); do\n    WORKER_NODE=\"${NODES[$i]}\"\n    echo \"Starting Ray worker $i at $WORKER_NODE\"\n    srun --nodes=1 --ntasks=1 -w \"$WORKER_NODE\" \\\n        ray start --address \"$HEAD_NODE_IP:$RAY_PORT\" \\\n        --block &\n    sleep 5\ndone\n\necho \"Waiting for Ray cluster to initialize...\"\nsleep 60\n\n# Start vLLM server\necho \"Starting 
vLLM server...\"\nRAY_ADDRESS=\"http://$HEAD_NODE_IP:$RAY_DASHBOARD_PORT\" ray job submit \\\n    --working-dir src/open_r1 \\\n    --no-wait \\\n    --job-id vllm-server \\\n    -- vllm serve \"$MODEL_PATH\" \\\n        --tensor-parallel-size 8 \\\n        --pipeline-parallel-size 4 \\\n        --gpu-memory-utilization 0.90 \\\n        --max-model-len 32768 \\\n        --max-num-batched-tokens 262144 \\\n        --max-num-seqs 128 \\\n        --max-seq-len-to-capture 32768 \\\n        --enable-chunked-prefill true \\\n        --preemption-mode recompute \\\n        --swap-space 128 \\\n        --trust-remote-code \\\n        --distributed-executor-backend ray\n\n# Wait for server with timeout\nTIMEOUT=3600  # 1h\nSTART_TIME=$(date +%s)\necho \"Waiting for vLLM server (http://$HEAD_NODE_IP:$SERVER_PORT)...\"\n\nwhile true; do\n    if curl -s -o /dev/null -w \"%{http_code}\" \"http://$HEAD_NODE_IP:$SERVER_PORT/health\" >/dev/null 2>&1; then\n        echo \"Server is ready at http://$HEAD_NODE_IP:$SERVER_PORT\"\n        break\n    fi\n\n    CURRENT_TIME=$(date +%s)\n    if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then\n        echo \"Error: Server failed to start within $TIMEOUT seconds\"\n        exit 1\n    fi\n\n    echo \"Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)\"\n    sleep 60\ndone\n\necho \"Checking available models...\"\ncurl \"http://$HEAD_NODE_IP:$SERVER_PORT/v1/models\"\nsleep 10\n\necho \"Executing sanity check...\"\ncurl \"http://$HEAD_NODE_IP:$SERVER_PORT/v1/completions\" \\\n    -H \"Content-Type: application/json\" \\\n    -d \"{\n        \\\"model\\\": \\\"default\\\",\n        \\\"prompt\\\": \\\"<｜begin▁of▁sentence｜><｜User｜>hi, how are you?<｜Assistant｜>\\\",\n        \\\"max_tokens\\\": 2048,\n        \\\"temperature\\\": 0.6\n    }\"\n\n# Keep the job running with health checks\nwhile true; do\n    if ! 
curl -s -o /dev/null \"http://$HEAD_NODE_IP:$SERVER_PORT/health\"; then\n        echo \"Error: Server health check failed\"\n        exit 1\n    fi\n    sleep 300\ndone"
  },
  {
    "path": "slurm/generate.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=deepseek-r1-generation\n#SBATCH --partition=hopper-prod\n#SBATCH --qos=normal\n#SBATCH --nodes=2\n#SBATCH --exclusive\n#SBATCH --gpus-per-node=8\n#SBATCH --output=./logs/%x-%j.out\n#SBATCH --error=./logs/%x-%j.err\n#SBATCH --time=04-00:00:00\n\n# Parse command line arguments\nwhile [[ $# -gt 0 ]]; do\n    case $1 in\n        --hf-dataset)\n            HF_DATASET=\"$2\"\n            shift 2\n            ;;\n        --hf-dataset-config)\n            HF_DATASET_CONFIG=\"$2\"\n            shift 2\n            ;;\n        --hf-dataset-split)\n            HF_DATASET_SPLIT=\"$2\"\n            shift 2\n            ;;\n        --prompt-column)\n            PROMPT_COLUMN=\"$2\"\n            shift 2\n            ;;\n        --prompt-template)\n            PROMPT_TEMPLATE=\"$2\"\n            shift 2\n            ;;\n        --model)\n            MODEL=\"$2\"\n            shift 2\n            ;;\n        --temperature)\n            TEMPERATURE=\"$2\"\n            shift 2\n            ;;\n        --top-p)\n            TOP_P=\"$2\"\n            shift 2\n            ;;\n        --max-new-tokens)\n            MAX_NEW_TOKENS=\"$2\"\n            shift 2\n            ;;\n        --num-generations)\n            NUM_GENERATIONS=\"$2\"\n            shift 2\n            ;;\n        --input-batch-size)\n            INPUT_BATCH_SIZE=\"$2\"\n            shift 2\n            ;;\n        --client-replicas)\n            CLIENT_REPLICAS=\"$2\"\n            shift 2\n            ;;\n        --timeout)\n            TIMEOUT=\"$2\"\n            shift 2\n            ;;\n        --retries)\n            RETRIES=\"$2\"\n            shift 2\n            ;;\n        --hf-output-dataset)\n            HF_OUTPUT_DATASET=\"$2\"\n            shift 2\n            ;;\n        --private)\n            PRIVATE=\"true\"\n            shift\n            ;;\n        *)\n            echo \"Unknown parameter: $1\"\n            exit 1\n            ;;\n    esac\ndone\n\nif [ 
-z \"$MODEL\" ] || [ -z \"$HF_DATASET\" ]; then\n    echo \"Error: --model and --hf-dataset are required parameters\"\n    exit 1\nfi\n\n# Set default values for optional parameters\nHF_DATASET_SPLIT=${HF_DATASET_SPLIT:-\"train\"}\nPROMPT_COLUMN=${PROMPT_COLUMN:-\"prompt\"}\nPROMPT_TEMPLATE=${PROMPT_TEMPLATE:-\"{{ instruction }}\"}\nMAX_NEW_TOKENS=${MAX_NEW_TOKENS:-8192}\nNUM_GENERATIONS=${NUM_GENERATIONS:-1}\nINPUT_BATCH_SIZE=${INPUT_BATCH_SIZE:-64}\nCLIENT_REPLICAS=${CLIENT_REPLICAS:-1}\nTIMEOUT=${TIMEOUT:-900}\nRETRIES=${RETRIES:-0}\nPRIVATE=${PRIVATE:-\"false\"}\n\n# Print all input arguments\necho \"Input arguments:\"\necho \"MODEL: $MODEL\"\necho \"HF_DATASET: $HF_DATASET\"\necho \"HF_DATASET_CONFIG: $HF_DATASET_CONFIG\"\necho \"HF_DATASET_SPLIT: $HF_DATASET_SPLIT\"\necho \"PROMPT_COLUMN: $PROMPT_COLUMN\"\necho \"PROMPT_TEMPLATE: $PROMPT_TEMPLATE\"\necho \"TEMPERATURE: $TEMPERATURE\"\necho \"TOP_P: $TOP_P\"\necho \"MAX_NEW_TOKENS: $MAX_NEW_TOKENS\"\necho \"NUM_GENERATIONS: $NUM_GENERATIONS\"\necho \"INPUT_BATCH_SIZE: $INPUT_BATCH_SIZE\"\necho \"CLIENT_REPLICAS: $CLIENT_REPLICAS\"\necho \"TIMEOUT: $TIMEOUT\"\necho \"RETRIES: $RETRIES\"\necho \"HF_OUTPUT_DATASET: $HF_OUTPUT_DATASET\"\necho \"PRIVATE: $PRIVATE\"\necho \"-------------------\"\n\nset -ex\n\nmodule load cuda/12.4\n\nexport LD_LIBRARY_PATH=.venv/lib/python3.11/site-packages/nvidia/nvjitlink/lib\n\necho \"SLURM_JOB_ID: $SLURM_JOB_ID\"\necho \"SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST\"\n\nsource openr1/bin/activate\n\n# Getting the node names\nnodes=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\")\nnodes_array=($nodes)\n\n# Get the IP address of the head node\nhead_node=${nodes_array[0]}\nhead_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n# Start Ray head node\nport=6379\nip_head=$head_node_ip:$port\nexport ip_head\necho \"IP Head: $ip_head\"\n\necho \"Starting HEAD at $head_node\"\nsrun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n    ray start --head 
--node-ip-address=\"$head_node_ip\" --port=$port \\\n    --dashboard-host=0.0.0.0 \\\n    --dashboard-port=8265 \\\n    --block &\n\n# Give some time to head node to start...\nsleep 10\n\n# Start Ray worker nodes\nworker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n# Start from 1 (0 is head node)\nfor ((i = 1; i <= worker_num; i++)); do\n    node_i=${nodes_array[$i]}\n    echo \"Starting WORKER $i at $node_i\"\n    srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n        ray start --address \"$ip_head\" \\\n        --block &\n    sleep 5\ndone\n\n# Give some time to the Ray cluster to gather info\necho \"Waiting a bit for Ray cluster to gather node info...\"\nsleep 60\n\n# Run vllm\nRAY_ADDRESS=\"http://$head_node_ip:8265\" ray job submit \\\n    --working-dir src/open_r1 \\\n    --no-wait \\\n    --job-id vllm-server \\\n    -- vllm serve $MODEL \\\n    --tensor-parallel-size $SLURM_GPUS_PER_NODE \\\n    --pipeline-parallel-size $SLURM_JOB_NUM_NODES \\\n    --gpu-memory-utilization=0.85 \\\n    --max-model-len 16384 \\\n    --enable-chunked-prefill \\\n    --trust-remote-code \\\n    --distributed-executor-backend ray\n\n# wait for vllm to load the model\necho \"Waiting for vLLM (http://$head_node_ip:8000) server to be up...\"\n\n# wait for vllm to load and serve the model\nwhile true; do\n    if curl -s -o /dev/null -w \"%{http_code}\" http://$head_node_ip:8000 >/dev/null 2>&1; then\n        echo \"Received response from http://$head_node_ip:8000\"\n        break\n    else\n        echo \"Still waiting... 
(Press Ctrl+C to cancel)\"\n        sleep 60\n    fi\ndone\n\necho \"Checking available models...\"\ncurl http://$head_node_ip:8000/v1/models\n\necho \"Executing sanity check...\"\ncurl http://$head_node_ip:8000/v1/completions \\\n    -H \"Content-Type: application/json\" \\\n    -d \"{\n        \\\"model\\\": \\\"$MODEL\\\",\n        \\\"prompt\\\": \\\"<｜begin▁of▁sentence｜><｜User｜>hi, how are you?<｜Assistant｜>\\\",\n        \\\"max_tokens\\\": 2048,\n        \\\"temperature\\\": 0.6\n    }\"\n\n# Finally submit the job to the cluster\necho \"Submitting job to ray cluster...\"\nRAY_ADDRESS=\"http://$head_node_ip:8265\" ray job submit \\\n    --working-dir src/open_r1 \\\n    --job-id generate \\\n    -- python -u generate.py \\\n    --model \"$MODEL\" \\\n    --hf-dataset \"$HF_DATASET\" \\\n    ${HF_DATASET_CONFIG:+--hf-dataset-config \"$HF_DATASET_CONFIG\"} \\\n    --hf-dataset-split \"$HF_DATASET_SPLIT\" \\\n    --prompt-column \"$PROMPT_COLUMN\" \\\n    --prompt-template \"$PROMPT_TEMPLATE\" \\\n    ${TEMPERATURE:+--temperature \"$TEMPERATURE\"} \\\n    ${TOP_P:+--top-p \"$TOP_P\"} \\\n    --max-new-tokens \"$MAX_NEW_TOKENS\" \\\n    --num-generations \"$NUM_GENERATIONS\" \\\n    --input-batch-size \"$INPUT_BATCH_SIZE\" \\\n    --client-replicas \"$CLIENT_REPLICAS\" \\\n    --timeout \"$TIMEOUT\" \\\n    --retries \"$RETRIES\" \\\n    ${HF_OUTPUT_DATASET:+--hf-output-dataset \"$HF_OUTPUT_DATASET\"} \\\n    ${PRIVATE:+--private} \\\n    --vllm-server-url \"http://$head_node_ip:8000/v1\"\n\nmkdir -p ray_logs\n\necho \"Downloading Ray job logs...\"\nRAY_ADDRESS=\"http://$head_node_ip:8265\" ray job logs --job-id vllm-server > ray_logs/vllm-server-${SLURM_JOB_ID}.log\nRAY_ADDRESS=\"http://$head_node_ip:8265\" ray job logs --job-id generate > ray_logs/generate-${SLURM_JOB_ID}.log"
  },
  {
    "path": "slurm/morph_router.slurm",
    "content": "#!/bin/bash\n\n#SBATCH --partition=hopper-cpu\n#SBATCH --mem=16g\n#SBATCH --cpus-per-task=16\n#SBATCH --output=/fsx/open-r1/logs/morph_router/%x-%j.out\n#SBATCH --err=/fsx/open-r1/logs/morph_router/%x-%j.err\n#SBATCH --requeue\n#SBATCH --time=7-00:00:00\n\n\necho \"Starting job\"\nset -x -e\n\nsource ~/.bashrc\nsource openr1/bin/activate\n\nsrun python scripts/morph_router.py --port 8001 --max_num_sandboxes 20\n"
  },
  {
    "path": "slurm/piston/README.md",
    "content": "# Piston workers (slurm)\n\nWe have built a [piston](https://github.com/engineer-man/piston) package to run IOI problems.\n\nTo launch a fleet of piston workers on a slurm cluster, you can adapt the paths in `launch_piston_workers.sh` and `launch_single_piston.sh` and run:\n```bash\nslurm/piston/launch_piston_workers.sh (number of workers to launch)\n```\n\nThis command will launch a slurm job for each worker, which will be called `piston-worker-<port>`, where `<port>` is the port where the worker will be listening.\n\n## First time setup\nYou will need to install the [IOI package](https://github.com/guipenedo/piston/tree/master/packages/cms_ioi/1.0.0) in the workers.\n1. Launch a single worker:\n```bash\nslurm/piston/launch_piston_workers.sh 1\n```\n\n2. Assuming it's running on `ip-10-53-86-146:1234`, send the package install request:\n\nFor IOI:\n```bash\ncurl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H \"Content-Type: application/json\" -d '{\"language\": \"cms_ioi\", \"version\": \"1.0.0\"}'\n```\n\nFor CodeForces:\n```bash\ncurl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H \"Content-Type: application/json\" -d '{\"language\": \"codeforces\", \"version\": \"1.0.0\"}'\n```\n\n3. You can now launch more workers and due to the shared mounted packages directory, they should already have the package installed.\n\nTo have the main script find the workers automatically, you can export the following environment variable:\n```bash\nexport PISTON_ENDPOINTS=slurm\n```\nAlternatively, you can add `PISTON_ENDPOINTS=slurm` to your .env file.\n\nYou can also change `PISTON_MAX_REQUESTS_PER_ENDPOINT`, which tries to limit how many simultaneous requests each worker will handle (1 by default). Keep in mind that this is a local limit and in distributed setups, as there is no global limit, workers might sometimes be overwhelmed when some processes hit the same worker.\n\nIf you would like to adapt the code to run without piston, please see the [ioi repo](https://github.com/huggingface/ioi).\nFor CodeForces, you should implement the [`run`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/run) and [`compile`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/compile) scripts.\n\n# Piston workers (local docker)\nThis will launch a single worker in a docker container. Consider launching multiple workers for better scalability. Replace 2000 with the port you want to use.\nMake sure to change `/path/to/local/packages` to the path you want to persist for package installs.\n\n```bash\ndocker run -d \\\n  --name piston_worker \\\n  -v /path/to/local/packages:/piston/packages \\\n  -e PORT=2000 \\\n  -e PISTON_COMPILE_TIMEOUT=60000 \\\n  -e PISTON_RUN_TIMEOUT=60000 \\\n  -e PISTON_OUTPUT_MAX_SIZE=1000000000 \\\n  -e PISTON_MAX_FILE_SIZE=1000000000 \\\n  -e PISTON_DISABLE_NETWORKING=true \\\n  -e PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index \\\n  -p 2000:2000 \\\n  --entrypoint /bin/bash \\\n  ghcr.io/engineer-man/piston@sha256:63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a \\\n  -c \"sed -i '/app.use(body_parser.urlencoded/c\\    app.use(body_parser.urlencoded({ extended: true, limit: \\\"512mb\\\" }));' src/index.js && \\\n      sed -i '/app.use(body_parser.json/c\\    app.use(body_parser.json({ limit: \\\"512mb\\\" }));' src/index.js && \\\n      node src\"\n```\n\nInstall the package:\nFor IOI:\n```bash\ncurl -X POST http://localhost:2000/api/v2/packages -H \"Content-Type: application/json\" -d '{\"language\": \"cms_ioi\", \"version\": \"1.0.0\"}'\n```\n\nFor CodeForces:\n```bash\ncurl -X POST http://localhost:2000/api/v2/packages -H \"Content-Type: application/json\" -d '{\"language\": \"codeforces\", \"version\": \"1.0.0\"}'\n```\n\nRemember to set `PISTON_ENDPOINTS`:\n```bash\nexport PISTON_ENDPOINTS=http://localhost:2000/api/v2,http://localhost:2001/api/v2,http://localhost:2002/api/v2\n```\n"
  },
  {
    "path": "slurm/piston/launch_piston_workers.sh",
    "content": "#!/bin/bash\n\n# this simple script will launch a bunch of piston workers on the HF science cluster\n\nN_INSTANCES=${1:-5}  # Default to 5 instances\n\nfor i in $(seq 1 $N_INSTANCES); do\n    # Find random (hopefully) available port\n    PORT=$(comm -23 <(seq 2000 10000 | sort) <(ss -tan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n1)\n    \n    # the job name format is important for the code to then be able to get a list of workers. `piston-worker-<port>`\n    sbatch \\\n        --job-name=\"piston-worker-$PORT\" \\\n        --export=ALL,PORT=$PORT \\\n        slurm/piston/launch_single_piston.sh\ndone"
  },
  {
    "path": "slurm/piston/launch_single_piston.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=piston_worker\n#SBATCH --output=/fsx/open-r1/logs/piston/worker-logs/%x-%j.out\n#SBATCH --error=/fsx/open-r1/logs/piston/worker-logs/%x-%j.out  # Redirect error logs to .out\n#SBATCH --cpus-per-task=2\n#SBATCH --mem-per-cpu=1950M\n#SBATCH --partition=hopper-cpu\n#SBATCH --time=48:00:00\n\n# sometimes if a bunch of workers start at the same time pyxis dies\nsleep $(( RANDOM % 20 ))\n\n# mounting the packages folder lets us not have to manually install the package on each instance\n# we use 63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a as the latest image requires isolate, which does not work on the HF science cluster (cgroups incompatibility)\n# feel free to try with the latest image\n# the code you see below increases the very constrained piston default limits, and sets the repo url to the one hosting our IOI package\nsrun --container-mounts=/fsx/guilherme/ioi2024/piston_files/packages:/piston/packages --container-image \"ghcr.io#engineer-man/piston:sha256:63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a\" \\\n    bash -c \"\n    export PISTON_COMPILE_TIMEOUT=60000\n    export PISTON_RUN_TIMEOUT=60000\n    export PISTON_OUTPUT_MAX_SIZE=1000000000\n    export PISTON_MAX_FILE_SIZE=1000000000\n    export PISTON_DISABLE_NETWORKING=true\n    export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index\n\n    sed -i '/app.use(body_parser.urlencoded/c\\    app.use(body_parser.urlencoded({ extended: true, limit: \\\"512mb\\\" }));' src/index.js\n    sed -i '/app.use(body_parser.json/c\\    app.use(body_parser.json({ limit: \\\"512mb\\\" }));' src/index.js\n\n    # Start the server (runs in the foreground)\n    node src\n    \"\n"
  },
  {
    "path": "slurm/serve_r1.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=r1-server\n#SBATCH --partition=hopper-prod\n#SBATCH --qos=normal\n#SBATCH --nodes=2\n#SBATCH --gpus-per-node=8\n#SBATCH --exclusive\n#SBATCH --output=./logs/%x_%j_%n.out\n#SBATCH --error=./logs/%x_%j_%n.err\n#SBATCH --time=7-00:00:00\n#SBATCH --ntasks-per-node=1\n\nset -exuo pipefail\n\nMODEL_PATH=\"deepseek-ai/DeepSeek-R1\"\nCONDA_ENV=\"sglang124\"\nROUTER_ADDRESS=\"\"\nSERVER_PORT=39877\nDIST_PORT=45000\n\n# TODO: Adjust these variables to your cluster configuration\nexport OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/\nexport TRITON_HOME=/scratch/serve_r1/triton/\nexport GLOO_SOCKET_IFNAME=\"enp71s0\"\nexport NCCL_SOCKET_IFNAME=\"enp71s0\"\n\nwhile getopts \"m:e:r:h\" opt; do\n    case $opt in\n        m) MODEL_PATH=\"$OPTARG\" ;;\n        e) CONDA_ENV=\"$OPTARG\" ;;\n        r) ROUTER_ADDRESS=\"$OPTARG\" ;;\n        h|?) echo \"Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV] [-r ROUTER_ADDRESS]\"; exit 1 ;;\n    esac\ndone\n\n# TODO: Environment setup, adjust to your cluster configuration\nmodule load cuda/12.4\nsource ~/.bashrc\nsource \"$CONDA_PREFIX/etc/profile.d/conda.sh\"\nconda activate \"$CONDA_ENV\" || { echo \"Failed to activate conda env $CONDA_ENV\"; exit 1; }\n\nFIRST_NODE=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | head -n1)\nFIRST_NODE_IP=$(srun --nodes=1 --ntasks=1 -w \"$FIRST_NODE\" hostname --ip-address)\n\n# Launch servers synchronously across all nodes\n# (--max-running-requests=56 is rough estimate to avoid too many evicted/preempted 16k-long requests)\nsrun --nodes=2 --ntasks=2 --ntasks-per-node=1 \\\n    bash -c \"python -m sglang.launch_server \\\n        --model-path '$MODEL_PATH' \\\n        --tp 16 \\\n        --dist-init-addr '$FIRST_NODE_IP:$DIST_PORT' \\\n        --nnodes 2 \\\n        --node-rank \\$SLURM_PROCID \\\n        --port '$SERVER_PORT' \\\n        --host 0.0.0.0 \\\n        --trust-remote-code \\\n        --max-running-requests 56 \\\n        --context-length 32768\" 
&\n\n# Wait for server with timeout\nTIMEOUT=3600  # 1h, but model loading should take ~30min\nSTART_TIME=$(date +%s)\necho \"Waiting for SGLang server (http://$FIRST_NODE_IP:$SERVER_PORT)...\"\n\nwhile true; do\n    if curl -s -o /dev/null -w \"%{http_code}\" \"http://$FIRST_NODE_IP:$SERVER_PORT/health\" >/dev/null 2>&1; then\n        echo \"Server is ready at http://$FIRST_NODE_IP:$SERVER_PORT\"\n        break\n    fi\n\n    CURRENT_TIME=$(date +%s)\n    if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then\n        echo \"Error: Server failed to start within $TIMEOUT seconds\"\n        exit 1\n    fi\n\n    echo \"Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)\"\n    sleep 60\ndone\n\n# Register with router only if address was provided\nif [ -n \"$ROUTER_ADDRESS\" ]; then\n    echo \"Registering with router at $ROUTER_ADDRESS...\"\n    curl -X POST \"http://$ROUTER_ADDRESS/add_worker?url=http://$FIRST_NODE_IP:$SERVER_PORT\" || true\n    sleep 10\nfi\n\necho \"Checking available models...\"\ncurl \"http://$FIRST_NODE_IP:$SERVER_PORT/v1/models\"\nsleep 10\n\necho \"Executing sanity check...\"\ncurl \"http://$FIRST_NODE_IP:$SERVER_PORT/v1/completions\" \\\n    -H \"Content-Type: application/json\" \\\n    -d \"{\n        \\\"model\\\": \\\"default\\\",\n        \\\"prompt\\\": \\\"<｜begin▁of▁sentence｜><｜User｜>hi, how are you?<｜Assistant｜>\\\",\n        \\\"max_tokens\\\": 2048,\n        \\\"temperature\\\": 0.6\n    }\"\n\n# Keep the job running with health checks\nwhile true; do\n    if ! curl -s -o /dev/null \"http://$FIRST_NODE_IP:$SERVER_PORT/health\"; then\n        echo \"Error: Server health check failed\"\n        exit 1\n    fi\n    sleep 300\ndone"
  },
  {
    "path": "slurm/serve_router.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=r1-router\n#SBATCH --partition=hopper-cpu\n#SBATCH --qos=high\n#SBATCH --nodes=1\n#SBATCH --cpus-per-task=8\n#SBATCH --mem-per-cpu=1875m\n#SBATCH --output=./logs/%x_%j_%n.out\n#SBATCH --error=./logs/%x_%j_%n.err\n#SBATCH --time=30-00:00:00\n#SBATCH --requeue\n\nset -exuo pipefail\n\n# TODO: Adjust these variables to your cluster configuration\nCONDA_ENV=\"sglang124\"\nROUTER_PORT=39876\n\ntrap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1\n\nwhile getopts \"e:h\" opt; do\n    case $opt in\n        e) CONDA_ENV=\"$OPTARG\" ;;\n        h|?) echo \"Usage: sbatch $0 [-e CONDA_ENV]\"; exit 1 ;;\n    esac\ndone\n\n# TODO: Environment setup, adjust to your cluster configuration\nsource ~/.bashrc\nsource \"$CONDA_PREFIX/etc/profile.d/conda.sh\"\nconda activate \"$CONDA_ENV\" || { echo \"Failed to activate conda env $CONDA_ENV\"; exit 1; }\n\npython -m sglang_router.launch_router \\\n    --port \"$ROUTER_PORT\" \\\n    --host 0.0.0.0 \\\n    --worker-startup-timeout-secs 300\n\n# Keep the job running with health checks\nwhile true; do\n    if ! curl -s -o /dev/null \"http://localhost:$ROUTER_PORT/health\"; then\n        echo \"Error: Router health check failed\"\n        exit 1\n    fi\n    sleep 300\ndone"
  },
  {
    "path": "slurm/train.slurm",
    "content": "#!/bin/bash\n#SBATCH --job-name=open_r1\n#SBATCH --ntasks-per-node=1\n#SBATCH --exclusive\n#SBATCH --gres=gpu:8\n#SBATCH --partition=hopper-prod  # Adjust this for your cluster\n#SBATCH --output=./logs/%x-%j.out\n#SBATCH --error=./logs/%x-%j.err\n#SBATCH --requeue\n#SBATCH --time=3-00:00:00\n\n\nif [[ \"$*\" == *\"--help\"* ]]; then\n  echo \"Usage: sbatch slurm/train.slurm [options]\"\n  echo \"Options:\"\n  echo \"  --model MODEL            Model name\"\n  echo \"  --task TASK              Task name (e.g. sft, grpo)\"\n  echo \"  --config SUFFIX          Configuration suffix (e.g. demo, v00.00)\"\n  echo \"  --accelerator CONFIG     Accelerator configuration name (e.g. zero3)\"\n  echo \"  --dp N                   Data parallelism for vLLM server (default: 1)\"\n  echo \"  --tp N                   Tensor parallelism for vLLM server (default: 1)\"\n  echo \"  --args \\\"ARGS\\\"          Optional arguments to pass to the training script\"\n  exit 0\nfi\n\n# Specific configuration optimized for the Hugging Face Compute Cluster\nmodule load cuda/12.4\nset -x -e\n\nsource ~/.bashrc\nsource openr1/bin/activate\nSTART_TIME=$(date +%s)\necho \"START TIME: $(date)\"\n\n# Refresh Weka on h4 cache\necho \"Refreshing Weka filesystem...\"\nfind -L /fsx/h4/ -type f | xargs -d '\\n' -r -n512 -P64 weka fs tier fetch\n\n# Default values\nMODEL=\"\"\nTASK=\"\"\nCONFIG_SUFFIX=\"\"\nACCELERATOR=\"\"\nDP=1\nTP=1\nOPTIONAL_ARGS=\"\"\n\n# Parse command line arguments\nwhile [[ $# -gt 0 ]]; do\n  case $1 in\n    --model)\n      MODEL=\"$2\"\n      shift 2\n      ;;\n    --task)\n      TASK=\"$2\"\n      shift 2\n      ;;\n    --config)\n      CONFIG_SUFFIX=\"$2\"\n      shift 2\n      ;;\n    --accelerator)\n      ACCELERATOR=\"$2\"\n      shift 2\n      ;;\n    --dp)\n      DP=\"$2\"\n      shift 2\n      ;;\n    --tp)\n      TP=\"$2\"\n      shift 2\n      ;;\n    --args)\n      OPTIONAL_ARGS=\"$2\"\n      shift 2\n      ;;\n    *)\n      echo \"Unknown option: $1\"\n 
     echo \"Use --help for usage information\"\n      exit 1\n      ;;\n  esac\ndone\n\n# Validate required arguments\nif [[ -z \"$MODEL\" || -z \"$TASK\" || -z \"$CONFIG_SUFFIX\" || -z \"$ACCELERATOR\" ]]; then\n  echo \"Error: Missing required arguments\"\n  echo \"Run with --help for usage information\"\n  exit 1\nfi\n\n\nCONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml\nGRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')\n\n# Split the string into individual arguments\nIFS=' ' read -ra ARGS <<< \"$OPTIONAL_ARGS\"\n# Loop through the arguments and find the one with \"--gradient_accumulation_steps\"\nfor arg in \"${ARGS[@]}\"; do\n    if [[ \"$arg\" == \"--gradient_accumulation_steps=\"* ]]; then\n        # Extract the value after the equals sign\n        GRAD_ACC_STEPS=\"${arg#*=}\"\n        break  # Exit the loop once we find the desired argument\n    fi\ndone\n\necho \"Gradient accumulation steps: $GRAD_ACC_STEPS\"\n\nMODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')\nREVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')\n\n# Distributed configuration\nNUM_NODES=$SLURM_NNODES\nGPUS_PER_NODE=8\nWORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))\nNODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))\nMASTER_ADDR=${NODELIST[0]}  # First node for main process\nMASTER_PORT=6000\nTRAIN_NODES=(\"${NODELIST[@]}\")\n\nUSE_VLLM=\"false\"\nif [[ -f \"$CONFIG_FILE\" ]] && grep -qE '^\\s*use_vllm:\\s*true' \"$CONFIG_FILE\"; then\n    USE_VLLM=\"true\"\nfi\n# if using vllm\nif [[ \"$USE_VLLM\" == \"true\" ]]; then\n     TRAIN_NODES=(\"${NODELIST[@]:0:$((NUM_NODES - 1))}\")\n     VLLM_NODE=${NODELIST[-1]} # Last node\n     WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE))\n     NUM_NODES=$((NUM_NODES - 1))\n     srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP --data_parallel_size $DP &\n\n     OPTIONAL_ARGS=\"$OPTIONAL_ARGS 
--vllm_server_host=$VLLM_NODE\"\nfi\n\n# force crashing on nccl issues like hanging broadcast\nexport NCCL_ASYNC_ERROR_HANDLING=1\n# export NCCL_DEBUG=INFO\n# export NCCL_DEBUG_SUBSYS=COLL\n# export NCCL_SOCKET_NTHREADS=1\n# export NCCL_NSOCKS_PERTHREAD=1\n# export CUDA_LAUNCH_BLOCKING=1\n\nexport CMD=\" \\\n    src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS\n    \"\n\nexport LAUNCHER=\"ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \\\n    --config_file recipes/accelerate_configs/$ACCELERATOR.yaml  \\\n    --gradient_accumulation_steps $GRAD_ACC_STEPS \\\n    --num_machines $NUM_NODES \\\n    --num_processes $WORLD_SIZE \\\n    --main_process_ip $MASTER_ADDR \\\n    --main_process_port $MASTER_PORT \\\n    --machine_rank $SLURM_PROCID \\\n    --rdzv_backend=c10d \\\n    --max_restarts 1 \\\n    --tee 3 \\\n    \"\n# srun error handling:\n# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks\n# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code\nNODELIST=$(IFS=,; echo \"${TRAIN_NODES[*]}\")\n\nSRUN_ARGS=\" \\\n    --wait=60 \\\n    --kill-on-bad-exit=1 \\\n    --nodes=$NUM_NODES \\\n    --ntasks=$NUM_NODES \\\n    --nodelist=$NODELIST\n    \"\nsrun $SRUN_ARGS bash -c \"$LAUNCHER $CMD\" 2>&1\n\nEND_TIME=$(date +%s)\necho \"END TIME: $(date)\"\nELAPSED_SECONDS=$((END_TIME - START_TIME))\nHOURS=$((ELAPSED_SECONDS / 3600))\nMINUTES=$(( (ELAPSED_SECONDS % 3600) / 60 ))\nSECONDS=$((ELAPSED_SECONDS % 60))\necho \"TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)\"\n"
  },
  {
    "path": "src/open_r1/__init__.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "src/open_r1/configs.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Literal, Optional\n\nimport trl\n\n\n@dataclass\nclass DatasetConfig:\n    \"\"\"Configuration for a dataset in a mixture.\"\"\"\n\n    id: str\n    config: Optional[str] = None\n    split: str = \"train\"\n    columns: Optional[list[str]] = None\n    weight: Optional[float] = None\n\n\n@dataclass\nclass DatasetMixtureConfig:\n    \"\"\"Configuration for a mixture of datasets.\"\"\"\n\n    datasets: list[DatasetConfig]\n    seed: int = 0\n    test_split_size: Optional[float] = None\n\n\n@dataclass\nclass ScriptArguments(trl.ScriptArguments):\n    \"\"\"\n    Extended version of ScriptArguments with support for dataset mixtures.\n\n    Args:\n        dataset_mixture (`dict[str, Any]` or `None`, *optional*, defaults to `None`):\n            Configuration for creating dataset mixtures with advanced options.\n            Format:\n              dataset_mixture:\n                datasets:\n                  - id: dataset_id1\n                    config: config_name\n                    columns:\n                      - col1\n                      - col2\n                    weight: 0.5\n                  - id: dataset_id2\n                    config: config_name\n                    columns:\n                      - col1\n                      - col2\n           
         weight: 0.5\n                seed: 42\n                test_split_size: 0.1\n    \"\"\"\n\n    # Override the dataset_name to make it optional\n    dataset_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Dataset name. Can be omitted if using dataset_mixture.\"}\n    )\n    dataset_mixture: Optional[dict[str, Any]] = field(\n        default=None,\n        metadata={\"help\": \"Configuration for creating dataset mixtures with advanced options like shuffling.\"},\n    )\n\n    def __post_init__(self):\n        if self.dataset_name is None and self.dataset_mixture is None:\n            raise ValueError(\"Either `dataset_name` or `dataset_mixture` must be provided\")\n\n        if self.dataset_mixture is not None:\n            if not isinstance(self.dataset_mixture, dict) or \"datasets\" not in self.dataset_mixture:\n                raise ValueError(\n                    \"dataset_mixture must be a dictionary with a 'datasets' key. \"\n                    \"Expected format: {'datasets': [...], 'seed': int}\"\n                )\n\n            datasets_list = []\n            datasets_data = self.dataset_mixture.get(\"datasets\", [])\n\n            if isinstance(datasets_data, list):\n                for dataset_config in datasets_data:\n                    datasets_list.append(\n                        DatasetConfig(\n                            id=dataset_config.get(\"id\"),\n                            config=dataset_config.get(\"config\"),\n                            split=dataset_config.get(\"split\", \"train\"),\n                            columns=dataset_config.get(\"columns\"),\n                            weight=dataset_config.get(\"weight\", 1.0),\n                        )\n                    )\n            else:\n                raise ValueError(\"'datasets' must be a list of dataset configurations\")\n\n            self.dataset_mixture = DatasetMixtureConfig(\n                datasets=datasets_list,\n                
seed=self.dataset_mixture.get(\"seed\", 0),\n                test_split_size=self.dataset_mixture.get(\"test_split_size\", None),\n            )\n\n            # Check that column names are consistent across all dataset configs\n            columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None]\n            if columns_sets:\n                first_columns = columns_sets[0]\n                if not all(columns == first_columns for columns in columns_sets):\n                    raise ValueError(\n                        \"Column names must be consistent across all dataset configurations in a mixture. \"\n                        f\"Found different column sets: {[list(cols) for cols in columns_sets]}\"\n                    )\n\n\n# TODO: add the shared options with a mixin to reduce code duplication\n@dataclass\nclass GRPOConfig(trl.GRPOConfig):\n    \"\"\"\n    args for callbacks, benchmarks etc\n    \"\"\"\n\n    benchmarks: list[str] = field(\n        default_factory=lambda: [],\n        metadata={\"help\": \"The benchmarks to run after training.\"},\n    )\n    callbacks: list[str] = field(\n        default_factory=lambda: [],\n        metadata={\"help\": \"The callbacks to run during training.\"},\n    )\n    chat_template: Optional[str] = field(default=None, metadata={\"help\": \"The chat template to use.\"})\n    hub_model_revision: Optional[str] = field(\n        default=\"main\", metadata={\"help\": \"The Hub model branch to push the model to.\"}\n    )\n    num_completions_to_print: int = field(default=0, metadata={\"help\": \"Number of completions to print.\"})\n    overwrite_hub_revision: bool = field(default=False, metadata={\"help\": \"Whether to overwrite the Hub revision.\"})\n    push_to_hub_revision: bool = field(default=False, metadata={\"help\": \"Whether to push to a Hub revision/branch.\"})\n    system_prompt: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"The optional system 
prompt to use.\"},\n    )\n    wandb_log_unique_prompts: bool = field(\n        default=True,\n        metadata={\n            \"help\": (\"Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.\")\n        },\n    )\n    wandb_entity: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"The entity to store runs under.\")},\n    )\n    wandb_project: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"The project to store runs under.\")},\n    )\n    wandb_run_group: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"The group to store runs under.\")},\n    )\n\n\n@dataclass\nclass SFTConfig(trl.SFTConfig):\n    \"\"\"\n    args for callbacks, benchmarks etc\n    \"\"\"\n\n    benchmarks: list[str] = field(\n        default_factory=lambda: [],\n        metadata={\"help\": \"The benchmarks to run after training.\"},\n    )\n    callbacks: list[str] = field(\n        default_factory=lambda: [],\n        metadata={\"help\": \"The callbacks to run during training.\"},\n    )\n    chat_template: Optional[str] = field(default=None, metadata={\"help\": \"The chat template to use.\"})\n    system_prompt: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"The optional system prompt to use for benchmarking.\"},\n    )\n    hub_model_revision: Optional[str] = field(\n        default=\"main\",\n        metadata={\"help\": \"The Hub model branch to push the model to.\"},\n    )\n    overwrite_hub_revision: bool = field(default=False, metadata={\"help\": \"Whether to overwrite the Hub revision.\"})\n    push_to_hub_revision: bool = field(default=False, metadata={\"help\": \"Whether to push to a Hub revision/branch.\"})\n    wandb_entity: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"The entity to store runs under.\")},\n    )\n    wandb_project: Optional[str] = field(\n        default=None,\n        
metadata={\"help\": (\"The project to store runs under.\")},\n    )\n    wandb_run_group: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"The group to store runs under.\")},\n    )\n\n\n@dataclass\nclass GRPOScriptArguments(ScriptArguments):\n    \"\"\"\n    Script arguments for the GRPO training script.\n\n    Args:\n        reward_funcs (`list[str]`):\n            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'ioi_code', 'code_format', 'soft_overlong_punishment'.\n        cosine_min_value_wrong (`float`):\n            Minimum reward for cosine scaling for wrong answers.\n        cosine_max_value_wrong (`float`):\n            Maximum reward for cosine scaling for wrong answers.\n        cosine_min_value_correct (`float`):\n            Minimum reward for cosine scaling for correct answers.\n        cosine_max_value_correct (`float`):\n            Maximum reward for cosine scaling for correct answers.\n        cosine_max_len (`int`):\n            Maximum length for cosine scaling.\n        code_language (`str`):\n            Language for code format reward.\n        max_completion_len (`int`):\n            Maximum number of tokens in completion.\n        soft_punish_cache (`int`):\n            Minimum number of tokens in completion.\n    \"\"\"\n\n    reward_funcs: list[str] = field(\n        default_factory=lambda: [\"accuracy\", \"format\", \"tag_count\"],\n        metadata={\n            \"help\": \"List of reward functions. 
Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'\"\n        },\n    )\n    cosine_min_value_wrong: float = field(\n        default=0.0,\n        metadata={\"help\": \"Minimum reward for wrong answers\"},\n    )\n    cosine_max_value_wrong: float = field(\n        default=-0.5,\n        metadata={\"help\": \"Maximum reward for wrong answers\"},\n    )\n    cosine_min_value_correct: float = field(\n        default=0.5,\n        metadata={\"help\": \"Minimum reward for correct answers\"},\n    )\n    cosine_max_value_correct: float = field(\n        default=1.0,\n        metadata={\"help\": \"Maximum reward for correct answers\"},\n    )\n    cosine_max_len: int = field(\n        default=1000,\n        metadata={\"help\": \"Maximum length for scaling\"},\n    )\n    repetition_n_grams: int = field(\n        default=3,\n        metadata={\"help\": \"Number of n-grams for repetition penalty reward\"},\n    )\n    repetition_max_penalty: float = field(\n        default=-1.0,\n        metadata={\"help\": \"Maximum (negative) penalty for repetition penalty reward\"},\n    )\n    code_language: str = field(\n        default=\"python\",\n        # '(?:python|cpp)'\n        metadata={\n            \"help\": \"Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages\",\n            \"choices\": [\"python\", \"javascript\", \"r\", \"java\", \"bash\", \"cpp\"],\n        },\n    )\n    code_eval_test_batch_size: int = field(\n        default=1,\n        metadata={\n            \"help\": \"for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. 
Useful to avoid overloading the eval server + save time on wrong solutions\"\n        },\n    )\n    code_eval_scoring_mode: Literal[\"pass_fail\", \"partial\", \"weighted_sum\"] = field(\n        default=\"weighted_sum\",\n        metadata={\"help\": \"use fraction of passed test cases as reward. If false, use 0/1 scoring.\"},\n    )\n    parallel_code_exec_per_proc: int = field(\n        default=2,\n        metadata={\n            \"help\": \"Number of parallel E2B code executions per process. Default of 2 is suitable for the Free Hobby tier of E2B with 8 GPUs used for training.\"\n        },\n    )\n\n    dataset_prompt_column: str = field(\n        default=\"prompt\",\n        metadata={\"help\": \"Column to use as prompts for training.\"},\n    )\n\n    e2b_router_url: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"URL for the E2B router. See scripts/e2b_router.py\"},\n    )\n\n    morph_router_url: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"URL for the MorphCloud router. See scripts/morph_router.py\"},\n    )\n\n    code_provider: Optional[str] = field(\n        default=\"e2b\",\n        metadata={\n            \"help\": \"Provider for code execution. Options: 'e2b', 'local', 'morph'.\",\n            \"choices\": [\"e2b\", \"local\", \"morph\"],\n        },\n    )\n\n    ioi_provider: Optional[str] = field(\n        default=\"piston\",\n        metadata={\n            \"help\": \"Provider for IOI code execution. Options: 'piston', 'morph'.\",\n            \"choices\": [\"piston\", \"morph\"],\n        },\n    )\n\n    max_completion_len: int = field(\n        default=16384,\n        metadata={\"help\": \"Maximum number of characters in completion.\"},\n    )\n    soft_punish_cache: int = field(\n        default=4096,\n        metadata={\"help\": \"Minimum number of characters in completion.\"},\n    )\n"
  },
  {
    "path": "src/open_r1/generate.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nfrom distilabel.llms import OpenAILLM\nfrom distilabel.pipeline import Pipeline\nfrom distilabel.steps import StepResources\nfrom distilabel.steps.tasks import TextGeneration\n\n\ndef build_distilabel_pipeline(\n    model: str,\n    base_url: str = \"http://localhost:8000/v1\",\n    prompt_column: Optional[str] = None,\n    prompt_template: str = \"{{ instruction }}\",\n    temperature: Optional[float] = None,\n    top_p: Optional[float] = None,\n    max_new_tokens: int = 8192,\n    num_generations: int = 1,\n    input_batch_size: int = 64,\n    client_replicas: int = 1,\n    timeout: int = 900,\n    retries: int = 0,\n) -> Pipeline:\n    generation_kwargs = {\"max_new_tokens\": max_new_tokens}\n\n    if temperature is not None:\n        generation_kwargs[\"temperature\"] = temperature\n\n    if top_p is not None:\n        generation_kwargs[\"top_p\"] = top_p\n\n    with Pipeline().ray() as pipeline:\n        TextGeneration(\n            llm=OpenAILLM(\n                base_url=base_url,\n                api_key=\"something\",\n                model=model,\n                timeout=timeout,\n                max_retries=retries,\n                generation_kwargs=generation_kwargs,\n            ),\n            template=prompt_template,\n            input_mappings=({\"instruction\": prompt_column} if 
prompt_column is not None else {}),\n            input_batch_size=input_batch_size,\n            num_generations=num_generations,\n            group_generations=True,\n            resources=StepResources(replicas=client_replicas),\n        )\n\n    return pipeline\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    from datasets import load_dataset\n\n    parser = argparse.ArgumentParser(description=\"Run distilabel pipeline for generating responses with DeepSeek R1\")\n    parser.add_argument(\n        \"--hf-dataset\",\n        type=str,\n        required=True,\n        help=\"HuggingFace dataset to load\",\n    )\n    parser.add_argument(\n        \"--hf-dataset-config\",\n        type=str,\n        required=False,\n        help=\"Dataset config to use\",\n    )\n    parser.add_argument(\n        \"--hf-dataset-split\",\n        type=str,\n        default=\"train\",\n        help=\"Dataset split to use\",\n    )\n    parser.add_argument(\n        \"--prompt-column\",\n        type=str,\n        default=\"prompt\",\n    )\n    parser.add_argument(\n        \"--prompt-template\",\n        type=str,\n        default=\"{{ instruction }}\",\n        help=\"Template string for formatting prompts.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        required=True,\n        help=\"Model name to use for generation\",\n    )\n    parser.add_argument(\n        \"--vllm-server-url\",\n        type=str,\n        default=\"http://localhost:8000/v1\",\n        help=\"URL of the vLLM server\",\n    )\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        help=\"Temperature for generation\",\n    )\n    parser.add_argument(\n        \"--top-p\",\n        type=float,\n        help=\"Top-p value for generation\",\n    )\n    parser.add_argument(\n        \"--max-new-tokens\",\n        type=int,\n        default=8192,\n        help=\"Maximum number of new tokens to generate\",\n    )\n    parser.add_argument(\n 
       \"--num-generations\",\n        type=int,\n        default=1,\n        help=\"Number of generations per problem\",\n    )\n    parser.add_argument(\n        \"--input-batch-size\",\n        type=int,\n        default=64,\n        help=\"Batch size for input processing\",\n    )\n    parser.add_argument(\n        \"--client-replicas\",\n        type=int,\n        default=1,\n        help=\"Number of client replicas for parallel processing\",\n    )\n    parser.add_argument(\n        \"--timeout\",\n        type=int,\n        default=600,\n        help=\"Request timeout in seconds (default: 600)\",\n    )\n    parser.add_argument(\n        \"--retries\",\n        type=int,\n        default=0,\n        help=\"Number of retries for failed requests (default: 0)\",\n    )\n    parser.add_argument(\n        \"--hf-output-dataset\",\n        type=str,\n        required=False,\n        help=\"HuggingFace repo to push results to\",\n    )\n    parser.add_argument(\n        \"--private\",\n        action=\"store_true\",\n        help=\"Whether to make the output dataset private when pushing to HF Hub\",\n    )\n\n    args = parser.parse_args()\n\n    print(\"\\nRunning with arguments:\")\n    for arg, value in vars(args).items():\n        print(f\"  {arg}: {value}\")\n    print()\n\n    print(f\"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...\")\n    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)\n    print(\"Dataset loaded!\")\n\n    pipeline = build_distilabel_pipeline(\n        model=args.model,\n        base_url=args.vllm_server_url,\n        prompt_template=args.prompt_template,\n        prompt_column=args.prompt_column,\n        temperature=args.temperature,\n        top_p=args.top_p,\n        max_new_tokens=args.max_new_tokens,\n        num_generations=args.num_generations,\n        input_batch_size=args.input_batch_size,\n        
client_replicas=args.client_replicas,\n        timeout=args.timeout,\n        retries=args.retries,\n    )\n\n    print(\"Running generation pipeline...\")\n    distiset = pipeline.run(\n        dataset=dataset,\n        dataset_batch_size=args.input_batch_size * 1000,\n        use_cache=False,\n    )\n    print(\"Generation pipeline finished!\")\n\n    if args.hf_output_dataset:\n        print(f\"Pushing resulting dataset to '{args.hf_output_dataset}'...\")\n        distiset.push_to_hub(args.hf_output_dataset, private=args.private)\n        print(\"Dataset pushed!\")\n"
  },
  {
    "path": "src/open_r1/grpo.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\n\nimport datasets\nimport transformers\nfrom transformers import set_seed\nfrom transformers.trainer_utils import get_last_checkpoint\n\nfrom open_r1.configs import GRPOConfig, GRPOScriptArguments\nfrom open_r1.rewards import get_reward_funcs\nfrom open_r1.utils import get_dataset, get_model, get_tokenizer\nfrom open_r1.utils.callbacks import get_callbacks\nfrom open_r1.utils.wandb_logging import init_wandb_training\nfrom trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef main(script_args, training_args, model_args):\n    # Set seed for reproducibility\n    set_seed(training_args.seed)\n\n    ###############\n    # Setup logging\n    ###############\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        handlers=[logging.StreamHandler(sys.stdout)],\n    )\n    log_level = training_args.get_process_log_level()\n    logger.setLevel(log_level)\n    datasets.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.enable_default_handler()\n    transformers.utils.logging.enable_explicit_format()\n\n    # Log on each process a small summary\n    logger.warning(\n        
f\"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\"\n        + f\" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}\"\n    )\n    logger.info(f\"Model parameters {model_args}\")\n    logger.info(f\"Script parameters {script_args}\")\n    logger.info(f\"Training parameters {training_args}\")\n\n    # Check for last checkpoint\n    last_checkpoint = None\n    if os.path.isdir(training_args.output_dir):\n        last_checkpoint = get_last_checkpoint(training_args.output_dir)\n    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:\n        logger.info(f\"Checkpoint detected, resuming training at {last_checkpoint=}.\")\n\n    if \"wandb\" in training_args.report_to:\n        init_wandb_training(training_args)\n\n    # Load the dataset\n    dataset = get_dataset(script_args)\n\n    ################\n    # Load tokenizer\n    ################\n    tokenizer = get_tokenizer(model_args, training_args)\n\n    ##############\n    # Load model #\n    ##############\n    logger.info(\"*** Loading model ***\")\n    model = get_model(model_args, training_args)\n\n    # Get reward functions from the registry\n    reward_funcs = get_reward_funcs(script_args)\n\n    # Format into conversation\n    def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):\n        prompt = []\n\n        if training_args.system_prompt is not None:\n            prompt.append({\"role\": \"system\", \"content\": training_args.system_prompt})\n\n        if prompt_column not in example:\n            raise ValueError(f\"Dataset Question Field Error: {prompt_column} is not supported.\")\n\n        prompt.append({\"role\": \"user\", \"content\": example[prompt_column]})\n        return {\"prompt\": prompt}\n\n    dataset = dataset.map(make_conversation)\n\n    for split in dataset:\n        if \"messages\" in dataset[split].column_names:\n      
      dataset[split] = dataset[split].remove_columns(\"messages\")\n\n    #############################\n    # Initialize the GRPO trainer\n    #############################\n    trainer = GRPOTrainer(\n        model=model,\n        reward_funcs=reward_funcs,\n        args=training_args,\n        train_dataset=dataset[script_args.dataset_train_split],\n        eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != \"no\" else None),\n        peft_config=get_peft_config(model_args),\n        callbacks=get_callbacks(training_args, model_args),\n        processing_class=tokenizer,\n    )\n\n    ###############\n    # Training loop\n    ###############\n    logger.info(\"*** Train ***\")\n    checkpoint = None\n    if training_args.resume_from_checkpoint is not None:\n        checkpoint = training_args.resume_from_checkpoint\n    elif last_checkpoint is not None:\n        checkpoint = last_checkpoint\n    train_result = trainer.train(resume_from_checkpoint=checkpoint)\n    metrics = train_result.metrics\n    metrics[\"train_samples\"] = len(dataset[script_args.dataset_train_split])\n    trainer.log_metrics(\"train\", metrics)\n    trainer.save_metrics(\"train\", metrics)\n    trainer.save_state()\n\n    ##################################\n    # Save model and create model card\n    ##################################\n    logger.info(\"*** Save model ***\")\n    # Align the model's generation config with the tokenizer's eos token\n    # to avoid unbounded generation in the transformers `pipeline()` function\n    trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id\n    trainer.save_model(training_args.output_dir)\n    logger.info(f\"Model saved to {training_args.output_dir}\")\n\n    # Save everything else on main process\n    kwargs = {\n        \"dataset_name\": script_args.dataset_name,\n        \"tags\": [\"open-r1\"],\n    }\n    if trainer.accelerator.is_main_process:\n        trainer.create_model_card(**kwargs)\n  
      # Restore k,v cache for fast inference\n        trainer.model.config.use_cache = True\n        trainer.model.config.save_pretrained(training_args.output_dir)\n\n    ##########\n    # Evaluate\n    ##########\n    if training_args.do_eval:\n        logger.info(\"*** Evaluate ***\")\n        metrics = trainer.evaluate()\n        metrics[\"eval_samples\"] = len(dataset[script_args.dataset_test_split])\n        trainer.log_metrics(\"eval\", metrics)\n        trainer.save_metrics(\"eval\", metrics)\n\n    #############\n    # push to hub\n    #############\n    if training_args.push_to_hub:\n        logger.info(\"Pushing to hub...\")\n        trainer.push_to_hub(**kwargs)\n\n\nif __name__ == \"__main__\":\n    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))\n    script_args, training_args, model_args = parser.parse_args_and_config()\n    main(script_args, training_args, model_args)\n"
  },
  {
    "path": "src/open_r1/rewards.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Reward functions for GRPO training.\"\"\"\n\nimport asyncio\nimport json\nimport math\nimport re\nfrom functools import partial, update_wrapper\nfrom typing import Callable, Dict, Literal, Optional\n\nfrom latex2sympy2_extended import NormalizationConfig\nfrom math_verify import LatexExtractionConfig, parse, verify\n\nfrom .utils.code_providers import get_provider\nfrom .utils.competitive_programming import (\n    SubtaskResult,\n    add_includes,\n    get_morph_client_from_env,\n    get_piston_client_from_env,\n)\nfrom .utils.competitive_programming import patch_code as cf_patch_code\nfrom .utils.competitive_programming import score_submission as cf_score_submission\nfrom .utils.competitive_programming import score_subtask\n\n\ndef accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:\n    \"\"\"Reward function that checks if the completion is the same as the ground truth.\"\"\"\n    contents = [completion[0][\"content\"] for completion in completions]\n    rewards = []\n    for content, sol in zip(contents, solution):\n        gold_parsed = parse(\n            sol,\n            extraction_mode=\"first_match\",\n        )\n        if len(gold_parsed) != 0:\n            # We require the answer to be provided in correct latex (no malformed operators)\n      
      answer_parsed = parse(\n                content,\n                extraction_config=[\n                    LatexExtractionConfig(\n                        normalization_config=NormalizationConfig(\n                            nits=False,\n                            malformed_operators=False,\n                            basic_latex=True,\n                            equations=True,\n                            boxed=\"all\",\n                            units=True,\n                        ),\n                        # Ensures that boxed is tried first\n                        boxed_match_priority=0,\n                        try_extract_without_anchor=False,\n                    )\n                ],\n                extraction_mode=\"first_match\",\n            )\n            # Compute binary rewards if verifiable, `None` otherwise to skip this example\n            try:\n                reward = float(verify(gold_parsed, answer_parsed))\n            except Exception as e:\n                print(f\"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}\")\n                reward = None\n        else:\n            # If the gold solution is not parseable, we assign `None` to skip this example\n            reward = None\n            print(\"Failed to parse gold solution: \", sol)\n        rewards.append(reward)\n\n    return rewards\n\n\ndef format_reward(completions, **kwargs):\n    \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n    pattern = r\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?\\n</answer>$\"\n    completion_contents = [completion[0][\"content\"] for completion in completions]\n    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]\n    return [1.0 if match else 0.0 for match in matches]\n\n\ndef tag_count_reward(completions, **kwargs) -> list[float]:\n    
\"\"\"Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.\n\n    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90\n    \"\"\"\n\n    def count_tags(text: str) -> float:\n        count = 0.0\n        if text.count(\"<think>\\n\") == 1:\n            count += 0.25\n        if text.count(\"\\n</think>\\n\") == 1:\n            count += 0.25\n        if text.count(\"\\n<answer>\\n\") == 1:\n            count += 0.25\n        if text.count(\"\\n</answer>\") == 1:\n            count += 0.25\n        return count\n\n    contents = [completion[0][\"content\"] for completion in completions]\n    return [count_tags(c) for c in contents]\n\n\ndef reasoning_steps_reward(completions, **kwargs):\n    r\"\"\"Reward function that checks for clear step-by-step reasoning.\n    Regex pattern:\n        Step \\d+: - matches \"Step 1:\", \"Step 2:\", etc.\n        ^\\d+\\. - matches numbered lists like \"1.\", \"2.\", etc. 
at start of line\n        \\n- - matches bullet points with hyphens\n        \\n\\* - matches bullet points with asterisks\n        First,|Second,|Next,|Finally, - matches transition words\n    \"\"\"\n    pattern = r\"(Step \\d+:|^\\d+\\.|\\n-|\\n\\*|First,|Second,|Next,|Finally,)\"\n    completion_contents = [completion[0][\"content\"] for completion in completions]\n    matches = [len(re.findall(pattern, content)) for content in completion_contents]\n\n    # Magic number 3 to encourage 3 steps and more, otherwise partial reward\n    return [min(1.0, count / 3) for count in matches]\n\n\ndef len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float:\n    \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n\n    Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n\n    Args:\n        completions: List of model completions\n        solution: List of ground truth solutions\n\n    Returns:\n        List of rewards where:\n        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n    \"\"\"\n    contents = [completion[0][\"content\"] for completion in completions]\n\n    # First check correctness of answers\n    correctness = []\n    for content, sol in zip(contents, solution):\n        gold_parsed = parse(\n            sol,\n            extraction_mode=\"first_match\",\n            extraction_config=[LatexExtractionConfig()],\n        )\n        if len(gold_parsed) == 0:\n            # Skip unparseable examples\n            correctness.append(True)  # Treat as correct to avoid penalizing\n            print(\"Failed to parse gold solution: \", sol)\n            continue\n\n        answer_parsed = parse(\n            content,\n            extraction_config=[\n                LatexExtractionConfig(\n                    
normalization_config=NormalizationConfig(\n                        nits=False,\n                        malformed_operators=False,\n                        basic_latex=True,\n                        equations=True,\n                        boxed=True,\n                        units=True,\n                    ),\n                    boxed_match_priority=0,\n                    try_extract_without_anchor=False,\n                )\n            ],\n            extraction_mode=\"first_match\",\n        )\n        correctness.append(verify(answer_parsed, gold_parsed))\n\n    # Calculate lengths\n    lengths = [len(content) for content in contents]\n    min_len = min(lengths)\n    max_len = max(lengths)\n\n    # If all responses have the same length, return zero rewards\n    if max_len == min_len:\n        return [0.0] * len(completions)\n\n    rewards = []\n    for length, is_correct in zip(lengths, correctness):\n        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n\n        if is_correct:\n            reward = lambda_val\n        else:\n            reward = min(0, lambda_val)\n\n        rewards.append(float(reward))\n\n    return rewards\n\n\ndef get_cosine_scaled_reward(\n    min_value_wrong: float = -1.0,\n    max_value_wrong: float = -0.5,\n    min_value_correct: float = 0.5,\n    max_value_correct: float = 1.0,\n    max_len: int = 1000,\n):\n    def cosine_scaled_reward(completions, solution, **kwargs):\n        \"\"\"Reward function that scales based on completion length using a cosine schedule.\n\n        Shorter correct solutions are rewarded more than longer ones.\n        Longer incorrect solutions are penalized less than shorter ones.\n\n        Args:\n            completions: List of model completions\n            solution: List of ground truth solutions\n\n        This function is parameterized by the following arguments:\n            min_value_wrong: Minimum reward for wrong answers\n            max_value_wrong: Maximum reward for wrong 
answers\n            min_value_correct: Minimum reward for correct answers\n            max_value_correct: Maximum reward for correct answers\n            max_len: Maximum length for scaling\n        \"\"\"\n        contents = [completion[0][\"content\"] for completion in completions]\n        rewards = []\n\n        for content, sol in zip(contents, solution):\n            gold_parsed = parse(\n                sol,\n                extraction_mode=\"first_match\",\n                extraction_config=[LatexExtractionConfig()],\n            )\n            if len(gold_parsed) == 0:\n                rewards.append(1.0)  # Skip unparseable examples\n                print(\"Failed to parse gold solution: \", sol)\n                continue\n\n            answer_parsed = parse(\n                content,\n                extraction_config=[\n                    LatexExtractionConfig(\n                        normalization_config=NormalizationConfig(\n                            nits=False,\n                            malformed_operators=False,\n                            basic_latex=True,\n                            equations=True,\n                            boxed=True,\n                            units=True,\n                        ),\n                        boxed_match_priority=0,\n                        try_extract_without_anchor=False,\n                    )\n                ],\n                extraction_mode=\"first_match\",\n            )\n\n            is_correct = verify(answer_parsed, gold_parsed)\n            gen_len = len(content)\n\n            # Apply cosine scaling based on length\n            progress = gen_len / max_len\n            cosine = math.cos(progress * math.pi)\n\n            if is_correct:\n                min_value = min_value_correct\n                max_value = max_value_correct\n            else:\n                # Swap min/max for incorrect answers\n                min_value = max_value_wrong\n                max_value = 
min_value_wrong\n\n            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)\n            rewards.append(float(reward))\n\n        return rewards\n\n    return cosine_scaled_reward\n\n\ndef get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = \"en\"):\n    \"\"\"\n    Computes N-gram repetition penalty as described in Appendix C.2 of https://huggingface.co/papers/2502.03373.\n    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py\n\n    Args:\n    ngram_size: size of the n-grams\n    max_penalty: Maximum (negative) penalty for wrong answers\n    language: Language of the text, defaults to `en`. Used to choose the way to split the text into n-grams.\n    \"\"\"\n    if max_penalty > 0:\n        raise ValueError(f\"max_penalty {max_penalty} should not be positive\")\n\n    if language == \"en\":\n\n        def zipngram(text: str, ngram_size: int):\n            words = text.lower().split()\n            return zip(*[words[i:] for i in range(ngram_size)]), words\n\n    elif language == \"zh\":\n        from transformers.utils.import_utils import _is_package_available\n\n        if not _is_package_available(\"jieba\"):\n            raise ValueError(\"Please install jieba to use Chinese language\")\n\n        def zipngram(text: str, ngram_size: int):\n            import jieba\n\n            seg_list = list(jieba.cut(text))\n            return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list\n\n    else:\n        raise ValueError(\n            f\"Word splitting for language `{language}` is not yet implemented. 
Please implement your own zip-ngram function.\"\n        )\n\n    def repetition_penalty_reward(completions, **kwargs) -> float:\n        \"\"\"\n        reward function that penalizes repetitions\n        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py\n\n        Args:\n            completions: List of model completions\n        \"\"\"\n\n        contents = [completion[0][\"content\"] for completion in completions]\n        rewards = []\n        for completion in contents:\n            if completion == \"\":\n                rewards.append(0.0)\n                continue\n\n            ngrams = set()\n            total = 0\n            ngram_array, words = zipngram(completion, ngram_size)\n\n            if len(words) < ngram_size:\n                rewards.append(0.0)\n                continue\n\n            for ng in ngram_array:\n                ngrams.add(ng)\n                total += 1\n\n            scaling = 1 - len(ngrams) / total\n            reward = scaling * max_penalty\n            rewards.append(reward)\n        return rewards\n\n    return repetition_penalty_reward\n\n\ndef _init_event_loop():\n    \"\"\"Initialize or get the current event loop.\"\"\"\n    try:\n        loop = asyncio.get_event_loop()\n    except RuntimeError:\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n    return loop\n\n\ndef ioi_code_reward(completions, test_batch_size: int = 1, provider_type: str = \"piston\", **kwargs) -> list[float]:\n    \"\"\"Reward function that evaluates IOI problems using a specified execution client.\n\n    Assumes the dataset has the same format as hf.co/datasets/open-r1/ioi\n\n    Args:\n        completions: List of model completions to evaluate\n        test_batch_size: Evaluate these many test cases in parallel, then check if any of them failed (0 score):\n                       if so stop evaluating; otherwise continue with the next batch of test 
cases.\n        provider_type: The execution provider to use (default: \"piston\"). Supported values: \"piston\", \"morph\"\n        **kwargs: Additional arguments passed from the dataset\n    \"\"\"\n    # Get the appropriate client based on provider_type\n    if provider_type == \"morph\":\n        execution_client = get_morph_client_from_env()\n    else:\n        # for info on setting up piston workers, see slurm/piston/README.md\n        execution_client = get_piston_client_from_env()\n\n    code_snippets = [\n        # note: grading is automatically skipped if no code is extracted\n        add_includes(extract_code(completion[-1][\"content\"], \"cpp\"), problem_id)\n        for completion, problem_id in zip(completions, kwargs[\"id\"])\n    ]\n\n    async def run_catch_exceptions(task):\n        try:\n            return await task\n        except Exception as e:\n            print(f\"Error from {provider_type} worker: {e}\")\n            return SubtaskResult()\n\n    problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]\n\n    loop = _init_event_loop()\n    evals = [\n        loop.create_task(\n            run_catch_exceptions(\n                score_subtask(\n                    execution_client,\n                    problem_data,\n                    code,\n                    test_batch_size=test_batch_size,\n                )\n            )\n        )\n        for problem_data, code in zip(problems_data, code_snippets)\n    ]\n    results = loop.run_until_complete(asyncio.gather(*evals))\n\n    return [result.score for result in results]\n\n\ndef cf_code_reward(\n    completions,\n    test_batch_size: int = 1,\n    patch_code: bool = False,\n    scoring_mode: Literal[\"pass_fail\", \"partial\", \"weighted_sum\"] = \"weighted_sum\",\n    **kwargs,\n) -> list[float]:\n    \"\"\"Reward function that evaluates Codeforces problems using Piston+our CF package.\n\n    Assumes the dataset has the same format as 
hf.co/datasets/open-r1/codeforces (verifiable-prompts subset)\n\n    test_batch_size: evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases.\n    \"\"\"\n    # for info on setting up piston workers, see slurm/piston/README.md\n    piston_client = get_piston_client_from_env()\n\n    languages = kwargs[\"language\"] if \"language\" in kwargs else [None] * len(completions)\n    code_snippets = [\n        # note: grading is automatically skipped if a problem has no tests\n        cf_patch_code(extract_code(completion[-1][\"content\"], language), language)\n        if patch_code\n        else extract_code(completion[-1][\"content\"], language)\n        for completion, language in zip(completions, languages)\n    ]\n\n    async def run_catch_exceptions(task):\n        try:\n            return await task\n        except Exception as e:\n            print(f\"Error from Piston worker: {e}\")\n            return None\n\n    # load problem data. 
undo separating kwargs by column\n    problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]\n\n    loop = _init_event_loop()\n    evals = [\n        loop.create_task(\n            run_catch_exceptions(\n                cf_score_submission(\n                    piston_client,\n                    problem_data,\n                    code,\n                    test_batch_size=test_batch_size,\n                    scoring_mode=scoring_mode,\n                    submission_language=problem_data.get(\"language\", None),\n                )\n            )\n        )\n        for problem_data, code in zip(problems_data, code_snippets)\n    ]\n    results = loop.run_until_complete(asyncio.gather(*evals))\n\n    return results\n\n\ndef extract_code(completion: str, language: str | None = \"python\") -> str:\n    if language is None:\n        return \"\"\n    pattern = re.compile(rf\"```{language}\\n(.*?)```\", re.DOTALL)\n    matches = pattern.findall(completion)\n    extracted_answer = matches[-1] if len(matches) >= 1 else \"\"\n    return extracted_answer\n\n\ndef binary_code_reward(\n    completions,\n    num_parallel: int = 2,\n    provider_type: str = \"e2b\",\n    enforce_same_language: bool = False,\n    **kwargs,\n) -> list[float]:\n    rewards = code_reward(\n        completions,\n        num_parallel=num_parallel,\n        provider_type=provider_type,\n        enforce_same_language=enforce_same_language,\n        **kwargs,\n    )\n    BINARY_THRESHOLD = 0.99\n\n    output = []\n    for reward in rewards:\n        if reward is None:\n            output.append(None)\n        else:\n            output.append(1.0 if reward > BINARY_THRESHOLD else 0.0)\n\n    return output\n\n\ndef code_reward(\n    completions,\n    num_parallel: int = 2,\n    provider_type: str = \"e2b\",\n    enforce_same_language: bool = False,\n    **kwargs,\n) -> list[float]:\n    \"\"\"Reward function that evaluates code snippets using a code execution provider.\n\n 
   Assumes the dataset contains a `verification_info` column with test cases.\n\n    Args:\n        completions: List of model completions to evaluate\n        num_parallel: Number of parallel code executions (default: 2)\n        provider_type: Which code execution provider to use (default: \"e2b\")\n        enforce_same_language: If True, verify all problems use the same language (default: False)\n        **kwargs: Additional arguments passed to the verification\n    \"\"\"\n    evaluation_script_template = \"\"\"\n    import subprocess\n    import json\n\n    def evaluate_code(code, test_cases):\n        passed = 0\n        total = len(test_cases)\n        exec_timeout = 5\n\n        for case in test_cases:\n            process = subprocess.run(\n                [\"python3\", \"-c\", code],\n                input=case[\"input\"],\n                text=True,\n                capture_output=True,\n                timeout=exec_timeout\n            )\n\n            if process.returncode != 0:  # Error in execution\n                continue\n\n            output = process.stdout.strip()\n\n            # TODO: implement a proper validator to compare against ground truth. 
For now we just check for exact string match on each line of stdout.\n            all_correct = True\n            for line1, line2 in zip(output.split('\\\\n'), case['output'].split('\\\\n')):\n                all_correct = all_correct and line1.strip() == line2.strip()\n\n            if all_correct:\n                passed += 1\n\n        success_rate = (passed / total)\n        return success_rate\n\n    code_snippet = {code}\n    test_cases = json.loads({test_cases})\n\n    evaluate_code(code_snippet, test_cases)\n    \"\"\"\n\n    code_snippets = [extract_code(completion[-1][\"content\"]) for completion in completions]\n    verification_info = kwargs[\"verification_info\"]\n\n    template = evaluation_script_template\n\n    scripts = [\n        template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info[\"test_cases\"])))\n        for code, info in zip(code_snippets, verification_info)\n    ]\n\n    language = verification_info[0][\"language\"]\n\n    if enforce_same_language:\n        all_same_language = all(v[\"language\"] == language for v in verification_info)\n        if not all_same_language:\n            raise ValueError(\"All verification_info must have the same language\", verification_info)\n\n    execution_provider = get_provider(\n        provider_type=provider_type,\n        num_parallel=num_parallel,\n        **kwargs,\n    )\n\n    return execution_provider.execute_scripts(scripts, [\"python\"] * len(scripts))\n\n\ndef get_code_format_reward(language: str = \"python\"):\n    \"\"\"Format reward function specifically for code responses.\n\n    Args:\n        language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages\n    \"\"\"\n\n    def code_format_reward(completions, **kwargs):\n        # if there is a language field, use it instead of the default language. 
This way we can have mixed language training.\n        languages = kwargs[\"language\"] if \"language\" in kwargs else [language] * len(completions)\n\n        completion_contents = [completion[0][\"content\"] for completion in completions]\n        matches = [\n            re.match(\n                rf\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?```{sample_language}.*?```.*?\\n</answer>$\",\n                content,\n                re.DOTALL | re.MULTILINE,\n            )\n            for content, sample_language in zip(completion_contents, languages)\n        ]\n        return [1.0 if match else 0.0 for match in matches]\n\n    return code_format_reward\n\n\ndef get_soft_overlong_punishment(max_completion_len, soft_punish_cache):\n    \"\"\"\n    Reward function that penalizes overlong completions. It is used to penalize overlong completions,\n    but not to reward shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476)\n\n    Args:\n        max_completion_len: Maximum length of the completion\n        soft_punish_cache: Length of the soft punishment interval that ends at `max_completion_len`; the penalty\n
            decreases linearly from 0 to -1 across it. If set to 0, only completions longer than\n            `max_completion_len` are penalized.\n    \"\"\"\n\n    def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:\n        \"\"\"Reward function that penalizes overlong completions.\"\"\"\n        rewards = []\n        for ids in completion_ids:\n            completion_length = len(ids)\n            if completion_length <= max_completion_len - soft_punish_cache:\n                rewards.append(0.0)\n            elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len:\n                rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache)\n            else:\n                rewards.append(-1.0)\n        return rewards\n\n    return soft_overlong_punishment_reward\n\n\ndef get_reward_funcs(script_args) -> list[Callable]:\n    REWARD_FUNCS_REGISTRY = {\n        \"accuracy\": accuracy_reward,\n        \"format\": format_reward,\n        \"reasoning_steps\": reasoning_steps_reward,\n        \"cosine\": get_cosine_scaled_reward(\n            min_value_wrong=script_args.cosine_min_value_wrong,\n            max_value_wrong=script_args.cosine_max_value_wrong,\n            min_value_correct=script_args.cosine_min_value_correct,\n            max_value_correct=script_args.cosine_max_value_correct,\n            max_len=script_args.cosine_max_len,\n        ),\n        \"repetition_penalty\": get_repetition_penalty_reward(\n            ngram_size=script_args.repetition_n_grams,\n            max_penalty=script_args.repetition_max_penalty,\n        ),\n        \"length\": len_reward,\n        \"code\": update_wrapper(\n            partial(\n                code_reward,\n                num_parallel=script_args.parallel_code_exec_per_proc,\n                provider_type=script_args.code_provider,\n                enforce_same_language=getattr(script_args, \"enforce_same_language\", False),\n            ),\n            code_reward,\n        ),\n        \"binary_code\": 
update_wrapper(\n            partial(\n                binary_code_reward,\n                num_parallel=script_args.parallel_code_exec_per_proc,\n                provider_type=script_args.code_provider,\n                enforce_same_language=getattr(script_args, \"enforce_same_language\", False),\n            ),\n            binary_code_reward,\n        ),\n        \"ioi_code\": update_wrapper(\n            partial(\n                ioi_code_reward,\n                test_batch_size=script_args.code_eval_test_batch_size,\n                provider_type=getattr(script_args, \"ioi_provider\", \"piston\"),\n            ),\n            ioi_code_reward,\n        ),\n        \"cf_code\": update_wrapper(\n            partial(\n                cf_code_reward,\n                test_batch_size=script_args.code_eval_test_batch_size,\n                scoring_mode=script_args.code_eval_scoring_mode,\n            ),\n            cf_code_reward,\n        ),\n        \"code_format\": get_code_format_reward(language=script_args.code_language),\n        \"tag_count\": tag_count_reward,\n        \"soft_overlong_punishment\": get_soft_overlong_punishment(\n            max_completion_len=script_args.max_completion_len,\n            soft_punish_cache=script_args.soft_punish_cache,\n        ),\n    }\n    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]\n\n    return reward_funcs\n"
  },
  {
    "path": "src/open_r1/sft.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nSupervised fine-tuning script for decoder language models.\n\nUsage:\n\n# One 1 node of 8 x H100s\naccelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \\\n    --model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \\\n    --dataset_name open-r1/Mixture-of-Thoughts \\\n    --dataset_config all \\\n    --eos_token '<|im_end|>' \\\n    --learning_rate 4.0e-5 \\\n    --num_train_epochs 5 \\\n    --max_seq_length 32768 \\\n    --per_device_train_batch_size 2 \\\n    --gradient_checkpointing \\\n    --bf16 \\\n    --use_liger_kernel \\\n    --output_dir data/OpenR1-Distill-7B\n\"\"\"\n\nimport logging\nimport os\nimport sys\n\nimport datasets\nimport transformers\nfrom transformers import set_seed\nfrom transformers.trainer_utils import get_last_checkpoint\n\nfrom open_r1.configs import ScriptArguments, SFTConfig\nfrom open_r1.utils import get_dataset, get_model, get_tokenizer\nfrom open_r1.utils.callbacks import get_callbacks\nfrom open_r1.utils.wandb_logging import init_wandb_training\nfrom trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef main(script_args, training_args, model_args):\n    set_seed(training_args.seed)\n\n    ###############\n    # Setup logging\n    ###############\n    logging.basicConfig(\n       
 format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        handlers=[logging.StreamHandler(sys.stdout)],\n    )\n    log_level = training_args.get_process_log_level()\n    logger.setLevel(log_level)\n    datasets.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.set_verbosity(log_level)\n    transformers.utils.logging.enable_default_handler()\n    transformers.utils.logging.enable_explicit_format()\n\n    logger.info(f\"Model parameters {model_args}\")\n    logger.info(f\"Script parameters {script_args}\")\n    logger.info(f\"Training parameters {training_args}\")\n\n    # Check for last checkpoint\n    last_checkpoint = None\n    if os.path.isdir(training_args.output_dir):\n        last_checkpoint = get_last_checkpoint(training_args.output_dir)\n    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:\n        logger.info(f\"Checkpoint detected, resuming training at {last_checkpoint=}.\")\n\n    if \"wandb\" in training_args.report_to:\n        init_wandb_training(training_args)\n\n    ######################################\n    # Load dataset, tokenizer, and model #\n    ######################################\n    dataset = get_dataset(script_args)\n    tokenizer = get_tokenizer(model_args, training_args)\n    model = get_model(model_args, training_args)\n\n    if tokenizer.chat_template is None:\n        logger.info(\"No chat template provided, defaulting to ChatML.\")\n        model, tokenizer = setup_chat_format(model, tokenizer, format=\"chatml\")\n\n    ############################\n    # Initialize the SFT Trainer\n    ############################\n    trainer = SFTTrainer(\n        model=model,\n        args=training_args,\n        train_dataset=dataset[script_args.dataset_train_split],\n        eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != \"no\" else None),\n        processing_class=tokenizer,\n        
peft_config=get_peft_config(model_args),\n        callbacks=get_callbacks(training_args, model_args),\n    )\n\n    ###############\n    # Training loop\n    ###############\n    logger.info(\"*** Train ***\")\n    checkpoint = None\n    if training_args.resume_from_checkpoint is not None:\n        checkpoint = training_args.resume_from_checkpoint\n    elif last_checkpoint is not None:\n        checkpoint = last_checkpoint\n    train_result = trainer.train(resume_from_checkpoint=checkpoint)\n    metrics = train_result.metrics\n    metrics[\"train_samples\"] = len(dataset[script_args.dataset_train_split])\n    trainer.log_metrics(\"train\", metrics)\n    trainer.save_metrics(\"train\", metrics)\n    trainer.save_state()\n\n    ##################################\n    # Save model and create model card\n    ##################################\n    logger.info(\"*** Save model ***\")\n    # Align the model's generation config with the tokenizer's eos token\n    # to avoid unbounded generation in the transformers `pipeline()` function\n    trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id\n    trainer.save_model(training_args.output_dir)\n    logger.info(f\"Model saved to {training_args.output_dir}\")\n\n    # Save everything else on main process\n    kwargs = {\n        \"dataset_name\": script_args.dataset_name,\n        \"tags\": [\"open-r1\"],\n    }\n    if trainer.accelerator.is_main_process:\n        trainer.create_model_card(**kwargs)\n        # Restore k,v cache for fast inference\n        trainer.model.config.use_cache = True\n        trainer.model.config.save_pretrained(training_args.output_dir)\n\n    ##########\n    # Evaluate\n    ##########\n    if training_args.do_eval:\n        logger.info(\"*** Evaluate ***\")\n        metrics = trainer.evaluate()\n        metrics[\"eval_samples\"] = len(dataset[script_args.dataset_test_split])\n        trainer.log_metrics(\"eval\", metrics)\n        trainer.save_metrics(\"eval\", metrics)\n\n    
#############\n    # push to hub\n    #############\n    if training_args.push_to_hub:\n        logger.info(\"Pushing to hub...\")\n        trainer.push_to_hub(**kwargs)\n\n\nif __name__ == \"__main__\":\n    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))\n    script_args, training_args, model_args = parser.parse_args_and_config()\n    main(script_args, training_args, model_args)\n"
  },
  {
    "path": "src/open_r1/utils/__init__.py",
    "content": "from .data import get_dataset\nfrom .import_utils import is_e2b_available, is_morph_available\nfrom .model_utils import get_model, get_tokenizer\n\n\n__all__ = [\"get_tokenizer\", \"is_e2b_available\", \"is_morph_available\", \"get_model\", \"get_dataset\"]\n"
  },
  {
    "path": "src/open_r1/utils/callbacks.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport subprocess\nfrom typing import List\n\nfrom transformers import TrainerCallback\nfrom transformers.trainer_callback import TrainerControl, TrainerState\nfrom transformers.training_args import TrainingArguments\n\nfrom .evaluation import run_benchmark_jobs\nfrom .hub import push_to_hub_revision\n\n\ndef is_slurm_available() -> bool:\n    # returns true if a slurm queueing system is available\n    try:\n        subprocess.run([\"sinfo\"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n        return True\n    except FileNotFoundError:\n        return False\n\n\nclass DummyConfig:\n    def __init__(self, **kwargs):\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n\nclass PushToHubRevisionCallback(TrainerCallback):\n    def __init__(self, model_config) -> None:\n        self.model_config = model_config\n\n    def on_save(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if state.is_world_process_zero:\n            global_step = state.global_step\n\n            # WARNING: if you use dataclasses.replace(args, ...) 
the accelerator dist state will be broken, so I do this workaround\n            # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken\n            dummy_config = DummyConfig(\n                hub_model_id=args.hub_model_id,\n                hub_model_revision=f\"{args.hub_model_revision}-step-{global_step:09d}\",\n                output_dir=f\"{args.output_dir}/checkpoint-{global_step}\",\n                system_prompt=args.system_prompt,\n            )\n\n            future = push_to_hub_revision(\n                dummy_config, extra_ignore_patterns=[\"*.pt\"]\n            )  # don't push the optimizer states\n\n            if is_slurm_available():\n                dummy_config.benchmarks = args.benchmarks\n\n                def run_benchmark_callback(_):\n                    print(f\"Checkpoint {global_step} pushed to hub.\")\n                    run_benchmark_jobs(dummy_config, self.model_config)\n\n                future.add_done_callback(run_benchmark_callback)\n\n\nCALLBACKS = {\n    \"push_to_hub_revision\": PushToHubRevisionCallback,\n}\n\n\ndef get_callbacks(train_config, model_config) -> List[TrainerCallback]:\n    callbacks = []\n    for callback_name in train_config.callbacks:\n        if callback_name not in CALLBACKS:\n            raise ValueError(f\"Callback {callback_name} not found in CALLBACKS.\")\n        callbacks.append(CALLBACKS[callback_name](model_config))\n\n    return callbacks\n"
  },
  {
    "path": "src/open_r1/utils/code_providers.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Code execution providers for executing and evaluating code snippets.\"\"\"\n\nimport abc\nimport asyncio\nfrom typing import List, Optional\n\nfrom ..utils import is_e2b_available, is_morph_available\n\n\nif is_e2b_available():\n    from e2b_code_interpreter import AsyncSandbox\n    from e2b_code_interpreter.models import Execution\n\n    from .routed_sandbox import RoutedSandbox\nelse:\n    AsyncSandbox = None\n    Execution = None\n    RoutedSandbox = None\n\nif is_morph_available():\n    from morphcloud.api import MorphCloudClient\n    from morphcloud.sandbox import Sandbox\n\n    from .routed_morph import RoutedMorphSandbox\nelse:\n    MorphCloudClient = None\n    Sandbox = None\n    RoutedMorphSandbox = None\n\n\nclass CodeExecutionProvider(abc.ABC):\n    \"\"\"Abstract base class for code execution providers.\"\"\"\n\n    @abc.abstractmethod\n    def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:\n        \"\"\"Execute multiple scripts and return their reward values.\n\n        Args:\n            scripts: List of code scripts to execute\n            language: The programming language of the scripts\n\n        Returns:\n            List of float rewards (one per script)\n        \"\"\"\n        pass\n\n\nclass E2BProvider(CodeExecutionProvider):\n    \"\"\"Provider that 
executes code using E2B sandboxes.\"\"\"\n\n    def __init__(self, num_parallel: int = 2, e2b_router_url: Optional[str] = None):\n        \"\"\"Initialize the E2B provider.\n\n        Args:\n            num_parallel: Number of parallel sandboxes to use\n            e2b_router_url: URL for the E2B router (if using router mode)\n        \"\"\"\n        if not is_e2b_available():\n            raise ImportError(\n                \"E2B is not available and required for this provider. Please install E2B with \"\n                \"`pip install e2b-code-interpreter` and add an API key to a `.env` file.\"\n            )\n\n        self.num_parallel = num_parallel\n        self.e2b_router_url = e2b_router_url\n\n    def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:\n        \"\"\"Execute scripts using E2B sandboxes.\n\n        If e2b_router_url is provided, uses the RoutedSandbox for batch processing.\n        Otherwise, uses direct AsyncSandbox with parallelization.\n        \"\"\"\n        if self.e2b_router_url is not None:\n            routed_sandbox = RoutedSandbox(router_url=self.e2b_router_url)\n\n            executions = routed_sandbox.run_code(\n                scripts=scripts,\n                languages=languages,\n                timeout=30,\n                request_timeout=28,\n            )\n\n            rewards = []\n            for execution in executions:\n                try:\n                    reward = float(execution.text)\n                    rewards.append(reward)\n                except Exception:\n                    rewards.append(None)\n            return rewards\n\n        try:\n            rewards = self._run_async_from_sync(scripts, languages, self.num_parallel)\n        except Exception as e:\n            print(f\"Error from E2B executor: {e}\")\n            rewards = [0.0] * len(scripts)\n\n        return rewards\n\n    def _run_async_from_sync(self, scripts: List[str], languages: List[str], num_parallel: 
int) -> List[float]:\n        \"\"\"Function wrapping the `_run_async` function.\"\"\"\n        try:\n            rewards = asyncio.run(self._run_async(scripts, languages, num_parallel))\n        except Exception as e:\n            print(f\"Error from E2B executor async: {e}\")\n            raise e\n\n        return rewards\n\n    async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:\n        semaphore = asyncio.Semaphore(num_parallel)\n\n        tasks = [self._run_script(script, languages, semaphore) for script in scripts]\n\n        results = await asyncio.gather(*tasks)\n        rewards = list(results)\n\n        return rewards\n\n    async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:\n        # We set a timeout margin, as the AsyncSandbox timeout does not seem to work\n        # These values are based on running 256 examples with the gold solution\n        # from open-r1/verifiable-coding-problems-python_decontaminated\n        # see scripts/benchmark_e2b.py\n\n        SANDBOX_TIMEOUT = 30\n        MARGIN = 2\n        REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN\n        ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN\n\n        async with semaphore:\n            try:\n                sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)\n                execution = await asyncio.wait_for(\n                    sandbox.run_code(script, languages=languages),\n                    timeout=ASYNCIO_TIMEOUT,\n                )\n                return float(execution.text)\n            except (TypeError, ValueError):\n                return 0.0\n            except asyncio.TimeoutError:\n                print(\"Operation timed out\")\n                return 0.0\n            except Exception as e:\n                print(f\"Error in `_run_script` from E2B sandbox ID {sandbox.sandbox_id} : {e}\")\n                return 0.0\n            
finally:\n                try:\n                    await sandbox.kill()\n                except Exception as e:\n                    print(f\"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id} : {e}\")\n\n\nclass MorphProvider(CodeExecutionProvider):\n    \"\"\"Provider that executes code using MorphCloud's Sandbox API.\"\"\"\n\n    def __init__(self, num_parallel: int = 2, morph_router_url: Optional[str] = None):\n        \"\"\"Initialize the Morph provider.\n\n        Args:\n            num_parallel: Number of parallel executions to use\n            morph_router_url: URL for the MorphCloud router (if using router mode)\n        \"\"\"\n        if not is_morph_available():\n            raise ImportError(\n                \"MorphCloud is not available and required for this provider. Please install MorphCloud with \"\n                \"`pip install morphcloud` and add an API key to a `.env` file.\"\n            )\n\n        try:\n            from dotenv import load_dotenv\n\n            load_dotenv()\n        except ImportError:\n            print(\"Warning: python-dotenv not installed. Environment variables must be set directly.\")\n\n        self.num_parallel = num_parallel\n        self.morph_router_url = morph_router_url\n\n        if self.morph_router_url is not None:\n            self.routed_sandbox = RoutedMorphSandbox(router_url=self.morph_router_url)\n            return\n\n        import os\n\n        self.api_key = os.getenv(\"MORPH_API_KEY\")\n        if not self.api_key:\n            raise ValueError(\"MorphCloud API key not found. 
Please set the MORPH_API_KEY environment variable.\")\n\n        try:\n            self.client = MorphCloudClient(api_key=self.api_key)\n            self.Sandbox = Sandbox\n        except ImportError as e:\n            raise ImportError(f\"Required MorphCloud dependencies not installed: {e}\")\n\n    def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:\n        \"\"\"Execute scripts using MorphCloud Sandbox API.\n\n        Args:\n            scripts: List of Python scripts to execute\n            language: Programming language\n\n        Returns:\n            List of float rewards (one per script)\n        \"\"\"\n\n        if hasattr(self, \"routed_sandbox\"):\n            try:\n                results = self.routed_sandbox.run_code(\n                    scripts=scripts,\n                    languages=languages,\n                    timeout=90,\n                    request_timeout=96,\n                )\n\n                rewards = []\n                for result in results:\n                    try:\n                        reward = float(result.text)\n                        rewards.append(reward)\n                    except (ValueError, AttributeError):\n                        rewards.append(0.0)\n                return rewards\n            except Exception as e:\n                print(f\"Error from MorphCloud router: {e}\")\n                return [0.0] * len(scripts)\n\n        import asyncio\n\n        try:\n            rewards = asyncio.run(self._run_async(scripts, languages, self.num_parallel))\n        except Exception as e:\n            print(f\"Error from MorphCloud executor: {e}\")\n            rewards = [0.0] * len(scripts)\n\n        return rewards\n\n    async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:\n        \"\"\"Run multiple scripts concurrently with limited parallelism.\n\n        Args:\n            scripts: List of scripts to execute\n            language: 
Programming language\n            num_parallel: Maximum number of concurrent executions\n\n        Returns:\n            List of rewards\n        \"\"\"\n\n        semaphore = asyncio.Semaphore(num_parallel)\n\n        tasks = [self._run_script(script, languages, semaphore) for script in scripts]\n\n        results = await asyncio.gather(*tasks)\n\n        return list(results)\n\n    async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:\n        \"\"\"Execute a single script in a MorphCloud Sandbox.\n\n        Args:\n            script: The script to execute\n            language: Programming language\n            semaphore: Semaphore to limit concurrency\n\n        Returns:\n            Float reward from script execution\n        \"\"\"\n        SANDBOX_TIMEOUT = 90\n        MARGIN = 6\n        ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN\n\n        sandbox = None\n        async with semaphore:\n            try:\n                sandbox = await asyncio.to_thread(self.Sandbox.new, client=self.client, ttl_seconds=SANDBOX_TIMEOUT)\n                result = await asyncio.wait_for(\n                    asyncio.to_thread(\n                        sandbox.run_code,\n                        script,\n                        languages=languages,\n                        timeout=SANDBOX_TIMEOUT,\n                    ),\n                    timeout=ASYNCIO_TIMEOUT,\n                )\n\n                reward = 0.0\n                try:\n                    if hasattr(result, \"text\") and result.text:\n                        lines = result.text.strip().split(\"\\n\")\n                        if lines:\n                            try:\n                                reward = float(lines[-1])\n                            except ValueError:\n                                try:\n                                    reward = float(result.text.strip())\n                                except ValueError:\n                            
        pass\n                    elif hasattr(result, \"stdout\") and result.stdout:\n                        lines = result.stdout.strip().split(\"\\n\")\n                        if lines:\n                            try:\n                                reward = float(lines[-1])\n                            except ValueError:\n                                pass\n                except (ValueError, AttributeError):\n                    pass\n\n                return reward\n\n            except asyncio.TimeoutError:\n                return 0.0\n            except Exception:\n                return 0.0\n            finally:\n                if sandbox:\n                    try:\n                        await asyncio.to_thread(sandbox.close)\n                        await asyncio.to_thread(sandbox.shutdown)\n                    except Exception:\n                        pass\n\n\ndef get_provider(provider_type: str = \"e2b\", **kwargs) -> CodeExecutionProvider:\n    \"\"\"Factory function to get the appropriate code execution provider.\n\n    Args:\n        provider_type: Type of provider to use (\"e2b\", \"morph\")\n        **kwargs: Additional arguments to pass to the provider\n\n    Returns:\n        An instance of CodeExecutionProvider\n    \"\"\"\n    num_parallel = kwargs.pop(\"num_parallel\", 2)\n\n    if provider_type == \"e2b\":\n        # Extract E2B-specific arguments\n        e2b_router_url = kwargs.pop(\"e2b_router_url\", None)\n        return E2BProvider(\n            num_parallel=num_parallel,\n            e2b_router_url=e2b_router_url,\n        )\n    elif provider_type == \"morph\":\n        # Extract Morph-specific arguments\n        morph_router_url = kwargs.pop(\"morph_router_url\", None)\n        return MorphProvider(\n            num_parallel=num_parallel,\n            morph_router_url=morph_router_url,\n        )\n    else:\n        raise ValueError(f\"Unknown provider type: {provider_type}\")\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/__init__.py",
    "content": "from .cf_scoring import score_submission\nfrom .code_patcher import patch_code\nfrom .ioi_scoring import SubtaskResult, score_subtask, score_subtasks\nfrom .ioi_utils import add_includes\nfrom .morph_client import get_morph_client_from_env\nfrom .piston_client import get_piston_client_from_env, get_slurm_piston_endpoints\n\n\n__all__ = [\n    \"get_piston_client_from_env\",\n    \"get_slurm_piston_endpoints\",\n    \"get_morph_client_from_env\",\n    \"patch_code\",\n    \"score_submission\",\n    \"score_subtask\",\n    \"score_subtasks\",\n    \"add_includes\",\n    \"SubtaskResult\",\n]\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/cf_scoring.py",
    "content": "import asyncio\nimport os\nfrom io import BytesIO\nfrom typing import Literal\n\nfrom async_lru import alru_cache\n\nfrom .piston_client import PistonClient\nfrom .utils import batched\n\n\nasync def score_single_test_case(\n    client: PistonClient,\n    problem_data: dict,\n    test_input: str,\n    test_output: str,\n    submission: str,\n    submission_language: str = \"cpp\",\n) -> tuple[str, str]:\n    if submission_language not in [\"python\", \"cpp\"]:\n        raise ValueError(f\"Invalid submission language: {submission_language}\")\n    try:\n        result = await client.send_execute(\n            {\n                \"files\": [\n                    {\"name\": f\"main.{submission_language}\", \"content\": submission},\n                    *(\n                        [{\"name\": \"checker.py\", \"content\": problem_data[\"generated_checker\"]}]\n                        if problem_data[\"generated_checker\"]\n                        else []\n                    ),\n                    {\"name\": \"input.txt\", \"content\": test_input},\n                    {\"name\": \"correct_output.txt\", \"content\": test_output},\n                    {\n                        \"name\": \"grader_config\",\n                        \"content\": \"\\n\".join(\n                            f\"{key}={value}\"\n                            for key, value in {\n                                \"TIME_LIMIT\": problem_data[\"time_limit\"],\n                                \"MEMORY_LIMIT\": problem_data[\"memory_limit\"],\n                                \"INPUT_MODE\": problem_data[\"input_mode\"],\n                            }.items()\n                        ),\n                    },\n                ],\n                \"run_timeout\": (problem_data[\"time_limit\"] + 10) * 1000,\n                # +10 seconds hard limit. 
time limits are handled by the codeforces script\n            },\n            language=\"cf_python3\" if submission_language == \"python\" else \"c++17\",\n        )\n    except Exception as e:\n        print(f\"Error scoring submission: {e}\")\n        return False\n\n    return result\n\n\n@alru_cache(maxsize=32)  # TODO make this configurable\nasync def get_generated_contest_tests(contest_id: str) -> list[dict]:\n    import pandas as pd\n\n    import aiofiles\n    import aiofiles.os\n\n    tests_folder = os.environ.get(\"CF_TESTS_FOLDER\", None)\n    if not tests_folder:\n        raise ValueError(\n            \"CF_TESTS_FOLDER environment variable not set! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information.\"\n        )\n    if not await aiofiles.os.path.exists(tests_folder):\n        raise ValueError(\n            f\"CF_TESTS_FOLDER path '{tests_folder}' does not exist! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. 
See https://huggingface.co/datasets/open-r1/codeforces for more information.\"\n        )\n    parquet_path = os.path.join(tests_folder, f\"test_cases_{int(contest_id):04d}.parquet\")\n    if not await aiofiles.os.path.exists(parquet_path):\n        return {}\n\n    # Read parquet file asynchronously\n    async with aiofiles.open(parquet_path, \"rb\") as f:\n        content = await f.read()\n        df = pd.read_parquet(BytesIO(content))\n\n    # Group by problem_id and convert to dictionary of lists\n    grouped_tests = df.groupby(\"problem_id\").apply(lambda x: x[[\"input\", \"output\"]].to_dict(\"records\")).to_dict()\n\n    return grouped_tests\n\n\nasync def get_generated_tests(problem_id: str) -> list[dict]:\n    contest_id = problem_id.split(\"/\")[0]\n    return (await get_generated_contest_tests(contest_id)).get(problem_id, [])\n\n\nasync def score_submission(\n    client: PistonClient,\n    problem_data: dict,\n    submission: str,\n    test_batch_size: int = 1,\n    scoring_mode: Literal[\"pass_fail\", \"partial\", \"weighted_sum\"] = \"weighted_sum\",\n    no_compile_reward: float = -0.1,\n    no_submission_reward: float = -1.0,\n    submission_language: str = \"cpp\",\n) -> float:\n    if submission_language not in [\"python\", \"cpp\"]:\n        raise ValueError(f\"Invalid submission language: {submission_language}\")\n    test_cases = problem_data[\"official_tests\"] + (await get_generated_tests(problem_data[\"id\"]))\n    # invalid/not a coding problem\n    if test_cases is None or len(test_cases) == 0:\n        return None\n    # no code extracted\n    if not submission:\n        return no_submission_reward\n\n    passed_test_cases = 0\n    # run one batch, check if any of them failed (0 score): if so stop evaluating (assuming non partial score); otherwise continue with the next batch of test cases.\n    for test_batch_to_run in batched(test_cases, test_batch_size) if test_batch_size >= 1 else [test_cases]:\n        results = await 
asyncio.gather(\n            *[\n                asyncio.create_task(\n                    score_single_test_case(\n                        client, problem_data, test_case[\"input\"], test_case[\"output\"], submission, submission_language\n                    )\n                )\n                for test_case in test_batch_to_run\n            ]\n        )\n        if any(result and result[\"compile\"][\"code\"] != 0 for result in results):\n            return no_compile_reward\n\n        tests_passed_results = [\n            result and result[\"run\"][\"code\"] == 0 and result[\"run\"][\"stdout\"].strip() == \"1\" for result in results\n        ]\n        if scoring_mode == \"pass_fail\" and any(not test_passed for test_passed in tests_passed_results):\n            break\n        passed_test_cases += sum(1 for test_passed in tests_passed_results if test_passed)\n\n    pass_fail_score = 1.0 if passed_test_cases == len(test_cases) else 0.0\n\n    if scoring_mode == \"pass_fail\":\n        return pass_fail_score\n    elif scoring_mode == \"partial\":\n        return passed_test_cases / len(test_cases)\n    elif scoring_mode == \"weighted_sum\":\n        return pass_fail_score + 0.1 * (passed_test_cases / len(test_cases))\n    else:\n        raise ValueError(f\"Invalid scoring mode: {scoring_mode}\")\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/code_patcher.py",
    "content": "import re\n\n\ndef fix_python3_imports(source_code):\n    \"\"\"\n    Fix common import and function changes between Python 3 versions\n\n    Args:\n        source_code (str): The Python source code to update\n\n    Returns:\n        str: The updated source code\n    \"\"\"\n    # Dictionary of patterns to replacements\n    replacements = [\n        # Fix collections.abc imports (changed in Python 3.3+)\n        (\n            r\"from collections import (Mapping|Sequence|Set|Container|MutableMapping|MutableSet|MutableSequence)\",\n            r\"from collections.abc import \\1\",\n        ),\n        # Fix imp module deprecation (deprecated in 3.4)\n        (r\"import imp\", r\"import importlib\"),\n        # Fix asyncio.async() to asyncio.ensure_future() (renamed in 3.4.4)\n        (r\"asyncio\\.async\\(\", r\"asyncio.ensure_future(\"),\n        # Fix inspect.getargspec to inspect.getfullargspec (deprecated in 3.5)\n        (r\"inspect\\.getargspec\", r\"inspect.getfullargspec\"),\n        # Fix array.array 'c' type code to 'b' (removed in 3.9)\n        (r\"array\\.array\\('c'\", r\"array.array('b'\"),\n        # Fix backslash line continuation with multiple newlines (Python-specific issue)\n        (r\"\\\\(\\r\\n|\\r|\\n)+\", \"\\\\\\n\"),\n        # some solutions use getlogin() to check if they are debugging or on an actual submission\n        (r\"(?:os\\s*\\.\\s*)?getlogin\\s*\\(\\s*\\)\", \"False\"),\n        # Fix usage of fractions.gcd (moved to math in 3.5)\n        # 1. Fix direct usage: fractions.gcd -> math.gcd\n        (r\"\\bfractions\\.gcd\\b\", r\"math.gcd\"),\n        # 2. Fix 'from fractions import gcd, X' -> 'from fractions import X' (start/middle)\n        (r\"(from\\s+fractions\\s+import\\s+(?:\\([^)]*)?)\\bgcd\\s*,\\s*\", r\"\\1\"),\n        # 3. 
Fix 'from fractions import X, gcd' -> 'from fractions import X' (end)\n        (r\"(from\\s+fractions\\s+import\\s+.*?\\S)\\s*,\\s*\\bgcd(\\s*\\)?\\s*(?:#.*)?)\", r\"\\1\\2\"),\n        # 4. Fix standalone 'from fractions import gcd' -> 'from math import gcd'\n        (r\"from\\s+fractions\\s+import\\s+\\(?\\s*gcd\\s*\\)?\", r\"\"),\n        # --- End: Replacement for the faulty line ---\n    ]\n\n    lines = source_code.splitlines()\n    last_import = max(\n        [\n            i\n            for i, line in enumerate(lines)\n            if line.strip().startswith(\"import\") or (line.strip().startswith(\"from\") and \"import\" in line)\n        ],\n        default=0,\n    )\n    import_section = \"\\n\".join(lines[: last_import + 1])\n    main_source = \"\\n\".join(lines[last_import:])\n\n    if \"fractions.gcd\" in source_code and \"import math\" not in source_code:\n        import_section += \"\\nimport math\"\n    elif \"gcd\" in source_code and \"from math import gcd\" not in source_code:\n        import_section += \"\\nfrom math import gcd\"\n\n    if \"set_int_max_str_digits\" not in source_code:\n        import_section += \"\\nimport sys\\nsys.set_int_max_str_digits(0)\"\n\n    source_code = import_section + \"\\n\" + main_source\n\n    # Apply each replacement\n    for pattern, replacement in replacements:\n        source_code = re.sub(pattern, replacement, source_code)\n\n    source_code = source_code.rstrip(\"\\\\\")\n\n    return source_code\n\n\ndef fix_cpp_includes(source_code):\n    # has most of the useful functions\n    code_header = \"#include <bits/stdc++.h>\\n\"\n    # use namespace std since models forget std:: often\n    if \"using namespace std;\" not in source_code and \"std::\" not in source_code:\n        code_header += \"\\nusing namespace std;\\n\\n\"\n    return code_header + source_code\n\n\ndef is_patchable(lang):\n    return lang in (\"python\", \"python3\", \"Python 3\", \"PyPy 3\", \"PyPy 3-64\", \"cpp\") or \"C++\" in 
lang\n\n\ndef patch_code(text, lang):\n    if not text:\n        return text\n    if lang in (\"python\", \"python3\", \"Python 3\", \"PyPy 3\", \"PyPy 3-64\"):\n        return fix_python3_imports(text)\n    elif \"cpp\" in lang or \"C++\" in lang:\n        return fix_cpp_includes(text)\n    return text\n\n\ntests = [\n    \"\"\"read = lambda: map(int, input().split())\nn, m, z = read()\nfrom fractions import gcd\nans = z // (n * m // gcd(n, m))\nprint(ans)\"\"\",\n    \"\"\"from fractions import Fraction,gcd\n\na,b,c,d = [int(x) for x in input().split()]\n\nif a*d > b*c:\n    num = a*d-b*c\n    denom = a*d\nelse:\n    num = b*c-a*d\n    denom = b*c\ndiv = gcd(num,denom)\nprint('%d/%d'%(num//div,denom//div))\"\"\",\n]\n\nif __name__ == \"__main__\":\n    for test in tests:\n        print(\"ORIGINAL:\", test, sep=\"\\n\\n\")\n        print(\"PATCHED:\", patch_code(test, \"Python 3\"), sep=\"\\n\\n\")\n        print(\"=\" * 50)\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/ioi_scoring.py",
    "content": "import asyncio\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Union\n\nfrom .ioi_utils import load_ioi_tests\nfrom .piston_client import PistonClient, PistonError\nfrom .utils import batched\n\n\n@dataclass\nclass TestResult:\n    \"\"\"\n    Represents the result of a single test case execution.\n\n    Attributes:\n        test_name: Name of the test case\n        score: Score achieved for this test (0.0 to 1.0)\n        status: Status code of the test result (e.g., 'AC', 'WA', 'TLE')\n        feedback: Detailed feedback message from the judge or an error message\n    \"\"\"\n\n    test_name: str\n    score: float = 0.0\n    status: str = \"SKIPPED\"\n    feedback: str = None\n\n\n@dataclass\nclass SubtaskResult:\n    \"\"\"\n    Represents the result of a subtask containing multiple test cases.\n\n    Attributes:\n        problem: Problem identifier\n        subtask: Subtask identifier\n        points: Maximum points available for this subtask\n        score_precision: Number of decimal places for score rounding\n        test_results: List of individual test case results\n    \"\"\"\n\n    problem: str = None\n    subtask: str = None\n\n    points: float = 0.0\n    score_precision: int = 2\n\n    test_results: list[TestResult] = field(default_factory=list)\n\n    @property\n    def status(self):\n        \"\"\"\n        Determines the overall status of the subtask based on the worst status among test results.\n        Status priorities are ordered from worst to best.\n\n        Returns:\n            str: The status with the highest priority (lowest value)\n        \"\"\"\n        status_prios = {\"CE\": -1, \"RE\": 0, \"WA\": 1, \"MLE\": 2, \"TLE\": 3, \"PA\": 4, \"AC\": 5, \"SKIPPED\": 999}\n        return min([x.status for x in self.test_results], key=lambda x: status_prios[x])\n\n    @property\n    def score(self):\n        \"\"\"\n        Calculates the raw score for the subtask as the minimum score across all test 
results.\n\n        Returns:\n            float: The rounded minimum score\n        \"\"\"\n        return (\n            0\n            if not self.test_results\n            else round(min([test_result.score for test_result in self.test_results]), self.score_precision)\n        )\n\n    @property\n    def weighted_score(self):\n        \"\"\"\n        Calculates the weighted score by multiplying the raw score by the available points.\n\n        Returns:\n            float: The rounded weighted score\n        \"\"\"\n        return (\n            0\n            if not self.test_results\n            else round(\n                min([test_result.score for test_result in self.test_results]) * self.points, self.score_precision\n            )\n        )\n\n    def to_dict(self):\n        \"\"\"\n        Converts the SubtaskResult to a dictionary representation.\n\n        Returns:\n            dict: Dictionary containing all subtask result data\n        \"\"\"\n        return {\n            \"problem\": self.problem,\n            \"subtask\": self.subtask,\n            \"score\": self.score,\n            \"weighted_score\": self.weighted_score,\n            \"points\": self.points,\n            \"score_precision\": self.score_precision,\n            \"status\": self.status,\n            \"test_results\": [asdict(test_result) for test_result in self.test_results],\n        }\n\n\ndef _extract_single_status(score: float, feedback: str) -> str:\n    \"\"\"\n    Determines the status code based on the score and feedback message.\n\n    Args:\n        score: The numeric score (0.0 to 1.0)\n        feedback: The feedback message from the execution\n\n    Returns:\n        str: Status code ('CE', 'MLE', 'TLE', 'WA', 'RE', 'AC', or 'PA')\n    \"\"\"\n    if score == 0.0:\n        if \"Compilation error\" in feedback:\n            return \"CE\"\n        elif \"Memory limit exceeded\" in feedback:\n            return \"MLE\"\n        elif \"Time limit exceeded\" in feedback:\n    
        return \"TLE\"\n        elif \"Output isn't correct\" in feedback:\n            return \"WA\"\n        else:\n            return \"RE\"\n    elif score == 1.0:\n        return \"AC\"\n    else:\n        return \"PA\"\n\n\nasync def score_single_test_case(\n    client: PistonClient, subtask: dict, test_name: str, test_input: str, test_output: str, submission: str\n) -> TestResult:\n    \"\"\"\n    Scores a single test case by running the submission against the provided input and output.\n\n    Args:\n        client: PistonClient instance for executing code\n        subtask: Dictionary containing subtask configuration\n        test_name: Name of the test case\n        test_input: Input data for the test case\n        test_output: Expected output for the test case\n        submission: Source code of the submission\n\n    Returns:\n        TestResult: Result of the test case execution\n    \"\"\"\n    # Run submission for this test case\n    score, feedback = await run_submission(client, subtask, test_input, submission, test_output)\n    score = float(score)\n\n    return TestResult(\n        test_name=test_name, score=score, status=_extract_single_status(score, feedback), feedback=feedback\n    )\n\n\nasync def score_subtask(\n    client: PistonClient,\n    subtask: dict,\n    submission: str,\n    test_case_run_cache: Union[dict, None] = None,\n    test_batch_size: int = 1,\n) -> SubtaskResult:\n    \"\"\"\n    Scores all test cases in a subtask.\n\n    Args:\n        client: PistonClient instance for executing code\n        subtask: Dictionary containing subtask configuration\n        test_cases: Dictionary mapping test names to (input, output) tuples\n        submission: Source code of the submission\n        test_case_run_cache: Optional cache of previously run test cases\n        test_batch_size: evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test 
cases.\n        -1 to evaluate all test cases in parallel\n    Returns:\n        SubtaskResult: Result of the subtask evaluation\n    \"\"\"\n    subtask_result = SubtaskResult(\n        problem=subtask[\"id\"],\n        subtask=subtask[\"subtask\"],\n        points=subtask[\"score\"],\n        score_precision=subtask[\"score_precision\"],\n        test_results=[],\n    )\n\n    # tests that are not cached\n    tests_to_run = [\n        (ti, test_name)\n        for ti, test_name in enumerate(subtask[\"test_names\"])\n        if test_case_run_cache is None or test_name not in test_case_run_cache\n    ]\n\n    # initialize test results with cached results or empty (SKIPPED) TestResult objects\n    subtask_result.test_results = [\n        test_case_run_cache[test_name]\n        if test_case_run_cache is not None and test_name in test_case_run_cache\n        else TestResult(test_name=test_name)\n        for test_name in subtask[\"test_names\"]\n    ]\n\n    # we skip submissions where no code was extracted\n    # no need to do anything, as we have a failed cached result\n    if not submission or any(\n        test_result.status != \"SKIPPED\" and test_result.score == 0.0 for test_result in subtask_result.test_results\n    ):\n        return subtask_result\n\n    if \"test_cases\" in subtask:\n        test_cases = subtask[\"test_cases\"]\n        if isinstance(subtask[\"test_cases\"], list):\n            test_cases = {test_name: test for test_name, test in zip(subtask[\"test_names\"], subtask[\"test_cases\"])}\n    else:\n        test_cases = load_ioi_tests(subtask[\"year\"], subtask[\"id\"])\n\n    # run one batch, check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases.\n    for test_batch_to_run in batched(tests_to_run, test_batch_size):\n        results = await asyncio.gather(\n            *[\n                asyncio.create_task(\n                    score_single_test_case(\n                        client, 
subtask, test_name, test_cases[test_name][0], test_cases[test_name][1], submission\n                    )\n                )\n                for _, test_name in test_batch_to_run\n            ]\n        )\n        for (ti, test_name), test_result in zip(test_batch_to_run, results):\n            if test_case_run_cache is not None:\n                test_case_run_cache[test_name] = test_result\n            subtask_result.test_results[ti] = test_result\n\n        # Stop early if it failed\n        if any(test_result.score == 0.0 for test_result in results):\n            break\n\n    return subtask_result\n\n\nasync def score_subtasks(\n    client: PistonClient, subtasks: list[dict], submission: str, skip_mode: bool = True\n) -> list[SubtaskResult]:\n    \"\"\"\n    Scores multiple subtasks for a submission.\n\n    Args:\n        client: PistonClient instance for executing code\n        subtasks: List of dictionaries containing subtask configurations\n        submission: Source code of the submission\n        skip_mode: If True, evaluates test by test and stops after the first failure. Otherwise, runs all tests in parallel. 
Should be True when evaluating a large number of submissions.\n\n    Returns:\n        list[SubtaskResult]: Results for all subtasks\n    \"\"\"\n    # avoid rerunning tests present in multiple subtasks\n    test_case_run_cache = {}\n\n    return [await score_subtask(client, subtask, submission, test_case_run_cache, skip_mode) for subtask in subtasks]\n\n\nasync def run_submission(\n    client: PistonClient, problem: dict, test_input: str, submission: str, test_output: str | None = None\n) -> tuple[str, str]:\n    \"\"\"\n    Executes a submission against a test case using the Piston execution environment.\n\n    Args:\n        client: PistonClient instance for executing code\n        problem: Dictionary containing problem configuration\n        test_input: Input data for the test case\n        submission: Source code of the submission\n        test_output: Optional expected output for the test case\n\n    Returns:\n        tuple[str, str]: A tuple containing (score, feedback)\n    \"\"\"\n    data = {\n        \"files\": [\n            # the actual submission\n            {\"name\": f\"graders/{problem['id'].lower()}.cpp\", \"content\": submission},\n            # pass the input\n            {\"name\": \"input.txt\", \"content\": test_input},\n            # pass the expected output\n            *([{\"name\": \"correct_output.txt\", \"content\": test_output}] if test_output else []),\n            # grader files\n            *({\"name\": name, \"content\": content} for name, content in problem[\"grader_files\"] if content),\n        ],\n        \"run_timeout\": round(\n            (problem[\"time_limit\"] + 3) * 1000\n        ),  # +3 seconds hard limit. 
time limits are handled by the ioi script\n        \"run_memory_limit\": problem[\"memory_limit\"],\n    }\n    return await execute_ioi(client, data)\n\n\nasync def execute_ioi(client, data) -> tuple[str, str]:\n    \"\"\"\n    Requests to the IOI package return the score as a float in the stdout, as well as optional feedback/errors in stderr.\n    Returns a tuple of (score, feedback).\n    \"\"\"\n    response = await client.send_execute(data)\n\n    if \"message\" in response:\n        raise PistonError(response[\"message\"])\n\n    if \"compile\" in response and response[\"compile\"][\"code\"] != 0:\n        return \"0\", \"Compilation error exit code \" + str(response[\"compile\"][\"code\"]) + \"\\n\" + response[\"compile\"][\n            \"stderr\"\n        ]\n\n    if \"run\" not in response:\n        raise PistonError(response)\n\n    if response[\"run\"][\"code\"] == 1 and \"MemoryError\" in response[\"run\"][\"stderr\"]:\n        return \"0\", \"Memory limit exceeded\"\n\n    # successful result\n    if response[\"run\"][\"stdout\"]:\n        return response[\"run\"][\"stdout\"], response[\"run\"][\"stderr\"]\n\n    if response[\"run\"][\"signal\"] == \"SIGKILL\":\n        return \"0\", \"Time limit exceeded\"\n\n    # other issues\n    if response[\"run\"][\"code\"] != 0:\n        raise PistonError(\n            f\"language={response['language']}, version={response['version']}, exit code={response['run']['code']}, stderr={response['run']['stderr']}, signal={response['run']['signal']}\"\n        )\n    return \"0\", \"Unknown error\"\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/ioi_utils.py",
    "content": "from collections import defaultdict\nfrom functools import lru_cache\n\nfrom datasets import load_dataset\n\n\ndef add_includes(code: str, problem_id: str) -> str:\n    \"\"\"\n    Fix common compilation errors for IOI problems.\n    \"\"\"\n    if not code:\n        return code\n    # has most of the useful functions\n    code_header = \"#include <bits/stdc++.h>\\n\"\n    # include the problem header\n    problem_header_include = f'#include \"{problem_id}.h\"'\n    if problem_header_include not in code:\n        code_header += problem_header_include + \"\\n\"\n    # use namespace std since models forget std:: often\n    if \"using namespace std;\" not in code and \"std::\" not in code:\n        code_header += \"\\nusing namespace std;\\n\\n\"\n    return code_header + code\n\n\n@lru_cache\ndef load_ioi_tests_for_year(year: int) -> dict[str, dict[str, tuple[str, str]]]:\n    \"\"\"\n    Load IOI tests for a given year.\n    \"\"\"\n    tests_dataset = load_dataset(\"open-r1/ioi-test-cases\", name=f\"{year}\", split=\"train\")\n    test_cases = defaultdict(dict)\n    for test_case in tests_dataset:\n        test_cases[test_case[\"problem_id\"]][test_case[\"test_name\"]] = test_case[\"test_input\"], test_case[\"test_output\"]\n    return test_cases\n\n\ndef load_ioi_tests(year: int, problem_id: str) -> dict[str, tuple[str, str]]:\n    \"\"\"\n    Load IOI tests for a given year and problem id.\n    \"\"\"\n    return load_ioi_tests_for_year(year)[problem_id]\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/morph_client.py",
    "content": "import asyncio\nimport json\nimport logging\nimport os\nimport tempfile\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom dotenv import load_dotenv\nfrom open_r1.utils.import_utils import is_morph_available\n\n\n# Replace direct imports with conditional imports\nif is_morph_available():\n    from morphcloud.api import Instance, InstanceExecResponse, MorphCloudClient\nelse:\n    Instance = None\n    InstanceExecResponse = None\n    MorphCloudClient = None\n\n\n# Silence verbose logs from dependencies\nlogging.getLogger(\"paramiko\").setLevel(logging.ERROR)\nlogging.getLogger(\"httpx\").setLevel(logging.ERROR)\n\n\nclass MorphCloudError(Exception):\n    pass\n\n\nclass MorphCloudExecutionClient:\n    def __init__(\n        self,\n        api_key: Optional[str] = None,\n        base_url: Optional[str] = None,\n        spans_log_path: Optional[str] = None,\n    ):\n        \"\"\"\n        Initialize the MorphCloud execution client.\n\n        Args:\n            api_key: Optional API key for MorphCloud. If not provided, will use MORPH_API_KEY env var.\n            base_url: Optional base URL for MorphCloud API. If not provided, will use default.\n            spans_log_path: Path to log API call spans to. Defaults to 'logs/morph_api_spans.jsonl'.\n        \"\"\"\n\n        self.client = MorphCloudClient(api_key=api_key, base_url=base_url)\n        self._snapshot_lock = asyncio.Lock()\n\n    async def _prepare_instance(self, snapshot_id=None) -> Instance:\n        \"\"\"\n        Prepare and start a MorphCloud instance.\n\n        Args:\n          snapshot_id: Optional snapshot ID to use. 
If None, will get or create base snapshot.\n\n        Returns:\n          Instance: The ready-to-use MorphCloud instance\n\n        Raises:\n          TimeoutError: If instance fails to start or become ready\n        \"\"\"\n\n        if not snapshot_id:\n            snapshot = await self._get_or_create_base_snapshot()\n            snapshot_id = snapshot.id\n\n        try:\n            instance = await self.client.instances.astart(\n                snapshot_id, ttl_seconds=600\n            )  # Auto-terminate after 10 minutes\n            await instance.await_until_ready(timeout=300)\n            return instance\n        except asyncio.TimeoutError as e:\n            print(f\"Timeout while preparing instance: {str(e)}\")\n            if instance:\n                try:\n                    await instance.astop()\n                except Exception:\n                    pass\n            raise\n\n    async def _prepare_files(self, data: Dict[str, Any], temp_dir: str) -> Tuple[str, Dict[str, Any], Dict[str, str]]:\n        \"\"\"\n        Process files, determine problem ID, and prepare configuration.\n\n        Args:\n            data: Dictionary containing file information\n            temp_dir: Local temporary directory for file operations\n\n        Returns:\n            tuple: (problem_id, grader_config, local_files)\n\n        Raises:\n            ValueError: If problem ID cannot be determined\n        \"\"\"\n        # Extract problem ID\n        problem_id = None\n        graders_files = []\n        for file in data[\"files\"]:\n            if file[\"name\"].startswith(\"graders/\") and file[\"name\"].endswith(\".cpp\"):\n                potential_id = os.path.basename(file[\"name\"]).split(\".\")[0]\n                if potential_id not in [\"grader\", \"manager\", \"stub\"]:\n                    problem_id = potential_id\n\n            if file[\"name\"].startswith(\"graders/\"):\n                graders_files.append(file)\n\n        if not problem_id:\n         
   raise ValueError(\"Could not determine problem ID from files\")\n\n        grader_config = {\n            \"task_type\": \"Batch\",\n            \"code\": problem_id,\n            \"time_limit\": data[\"run_timeout\"] / 1000,\n            \"memory_limit\": data[\"run_memory_limit\"] * 1024 * 1024,\n        }\n\n        for file in graders_files:\n            if \"manager.cpp\" in file[\"name\"]:\n                grader_config[\"task_type\"] = \"Communication\"\n                grader_config[\"task_type_parameters_Communication_num_processes\"] = 1\n                grader_config[\"task_type_parameters_Communication_user_io\"] = \"std_io\"\n                break\n\n        config_path = os.path.join(temp_dir, \"grader_config.json\")\n        with open(config_path, \"w\") as f:\n            json.dump(grader_config, f)\n\n        local_files = {\"grader_config.json\": config_path}\n\n        for file in data[\"files\"]:\n            local_path = os.path.join(temp_dir, os.path.basename(file[\"name\"]))\n            with open(local_path, \"w\") as f:\n                f.write(file[\"content\"])\n            local_files[file[\"name\"]] = local_path\n\n        return problem_id, grader_config, local_files\n\n    async def _upload_files(self, instance: Instance, local_files: Dict[str, str]) -> bool:\n        \"\"\"\n        Upload all necessary files to the instance.\n\n        Args:\n            instance: The MorphCloud instance\n            local_files: Dictionary mapping remote paths to local file paths\n\n        Returns:\n            bool: True if all uploads were successful\n\n        Raises:\n            TimeoutError: If uploads time out\n        \"\"\"\n        for remote_name, local_path in local_files.items():\n            target_path = f\"/workspace/{remote_name}\"\n            dir_path = os.path.dirname(target_path)\n\n            if dir_path != \"/workspace\":\n                await instance.aexec(f\"mkdir -p {dir_path}\")\n\n            await 
instance.aupload(local_path, target_path)\n\n        await instance.aupload(local_files[\"grader_config.json\"], \"/workspace/graders/grader_config.json\")\n\n        return True\n\n    async def _compile_code(self, instance: Instance) -> InstanceExecResponse:\n        \"\"\"\n        Compile the code on the instance.\n\n        Args:\n            instance: The MorphCloud instance\n\n        Returns:\n            InstanceExecResponse: Result of compilation\n\n        Raises:\n            RuntimeError: If compilation fails\n        \"\"\"\n        compile_result = await instance.aexec(\"cd /workspace && ./compile\")\n\n        if compile_result.exit_code != 0:\n            raise RuntimeError(f\"Compilation error exit code {compile_result.exit_code}\\n{compile_result.stderr}\")\n\n        return compile_result\n\n    async def _run_tests(self, instance: Instance, data: Dict[str, Any]) -> Tuple[str, str]:\n        \"\"\"\n        Run tests and evaluate results.\n\n        Args:\n            instance: The MorphCloud instance\n            data: Dictionary containing runtime parameters\n\n        Returns:\n            tuple: (score, feedback)\n\n        Raises:\n            TimeoutError: If test execution times out\n        \"\"\"\n        hard_timeout = data[\"run_timeout\"] / 1000 + 3\n        run_command = f\"cd /workspace && timeout {hard_timeout}s ./run\"\n\n        run_result = await instance.aexec(run_command)\n\n        if run_result.exit_code == 124 or run_result.exit_code == 137 or run_result.exit_code == 143:\n            return \"0\", \"Time limit exceeded\"\n\n        if run_result.exit_code != 0 and \"Memory limit exceeded\" in run_result.stderr:\n            return \"0\", \"Memory limit exceeded\"\n\n        if run_result.stdout:\n            return run_result.stdout.strip(), run_result.stderr.strip()\n\n        if run_result.exit_code != 0:\n            return (\n                \"0\",\n                f\"Runtime error with exit code 
{run_result.exit_code}\\n{run_result.stderr}\",\n            )\n\n        return \"0\", \"Unknown error\"\n\n    async def _execute_with_instance(self, instance: Instance, data: Dict[str, Any], temp_dir: str) -> Tuple[str, str]:\n        \"\"\"Execute code using a prepared instance.\n\n        Args:\n            instance: Ready MorphCloud instance\n            data: Execution data\n            temp_dir: Temporary directory for file operations\n\n        Returns:\n            Tuple of (score, feedback)\n\n        Raises:\n            Exception: Passes through exceptions for retry handling\n        \"\"\"\n        await instance.await_until_ready(timeout=300)\n\n        problem_id, grader_config, local_files = await self._prepare_files(data, temp_dir)\n\n        await self._upload_files(instance, local_files)\n\n        try:\n            await self._compile_code(instance)\n        except RuntimeError as e:\n            return \"0\", str(e)\n\n        score, feedback = await self._run_tests(instance, data)\n        return score, feedback\n\n    async def _execute(self, data: Dict[str, Any]) -> Tuple[str, str]:\n        \"\"\"\n        Internal implementation of execute with no retry logic.\n\n        Args:\n            data: Dictionary containing execution data\n\n        Returns:\n            Tuple of (score, feedback)\n\n        Raises:\n            Exception: If execution fails\n        \"\"\"\n        instance = None\n\n        # Set timeouts to ensure we don't block indefinitely\n        # INSTANCE_TIMEOUT = 300  # 5 minutes for instance operations\n        TOTAL_EXECUTION_TIMEOUT = 600  # 10 minutes total execution time\n\n        with tempfile.TemporaryDirectory(prefix=\"morph_exec_\") as temp_dir:\n            snapshot = await self._get_or_create_base_snapshot()\n            instance = await self.client.instances.astart(\n                snapshot.id, ttl_seconds=600\n            )  # Auto-terminate after 10 minutes\n\n            async with instance:\n         
       # Use asyncio.wait_for to add overall timeout to the execution process\n                return await asyncio.wait_for(\n                    self._execute_with_instance(instance, data, temp_dir),\n                    timeout=TOTAL_EXECUTION_TIMEOUT,\n                )\n\n    async def execute(self, data: Dict[str, Any]) -> Tuple[str, str]:\n        \"\"\"\n        Execute code on MorphCloud based on the provided data with enhanced debugging and recovery.\n\n        Orchestrates the following steps with proper error handling and retries:\n        1. Prepare an instance (with retry)\n        2. Set up workspace (with retry)\n        3. Prepare and upload files (with retry)\n        4. Compile code (with retry)\n        5. Run tests (with retry)\n\n        Args:\n            data: Dictionary containing:\n                - files: List of file objects with name and content fields\n                - run_timeout: Timeout in milliseconds\n                - run_memory_limit: Memory limit in MB\n\n        Returns:\n            Tuple of (score, feedback) where:\n                - score is a string representation of a float between 0.0 and 1.0\n                - feedback is a string with execution details\n        \"\"\"\n        # TODO: would be faster to pass info about the subtask as well to create a snapshot per subtask\n        # would cache the uploads of all files other than the submission: input.txt, correct_output.txt, grader files\n        # rather than reusing the snapshot that only has the compile/run scripts on it\n        # currently, run_submission -> client.execute(data) does not easily pass subtask info\n\n        # Retry configuration\n        max_retries = 4\n        base_delay = 1.0\n\n        # Try execution with retries and exponential backoff\n        for attempt in range(max_retries + 1):\n            try:\n                return await self._execute(data)\n\n            except asyncio.TimeoutError:\n                if attempt < max_retries:\n      
              print(f\"Execution timed out, retrying ({attempt + 1}/{max_retries})\")\n                else:\n                    return \"0\", \"Execution timed out after multiple retries\"\n\n            except Exception as e:\n                # Calculate exponential backoff\n                if attempt < max_retries:\n                    retry_delay = min(base_delay * (2**attempt), 30)  # Exponential backoff, capped at 30 seconds\n\n                    print(\n                        f\"Execution failed with {type(e).__name__}: {str(e)}, retrying in {retry_delay:.2f}s ({attempt + 1}/{max_retries})\"\n                    )\n                    await asyncio.sleep(retry_delay)\n                else:\n                    print(f\"Execution failed after {max_retries} retries: {type(e).__name__}: {str(e)}\")\n                    return \"0\", f\"Execution failed after multiple retries: {str(e)}\"\n\n    async def _get_or_create_base_snapshot(self):\n        \"\"\"Get or create a snapshot with the necessary dependencies and scripts for evaluation.\"\"\"\n\n        async with self._snapshot_lock:\n            base_snapshots = await self.client.snapshots.alist(digest=\"ioi-evaluation-morph\")\n\n            if not base_snapshots:\n                print(\"Creating base snapshot with build-essential cmake and g++\")\n\n                # Create base snapshot with minimal specs\n                base_snapshot = await self.client.snapshots.acreate(\n                    vcpus=2,\n                    memory=4096,\n                    disk_size=10240,\n                    metadata={\"purpose\": \"ioi_evaluation\"},\n                )\n\n                # Start a temporary instance from the base snapshot\n                temp_instance = await self.client.instances.astart(\n                    base_snapshot.id, ttl_seconds=900\n                )  # Auto-terminate after 15 minutes\n\n                try:\n                    # Wait for the instance to be ready\n                    
await temp_instance.await_until_ready(timeout=300)\n\n                    # Get script contents\n                    compile_script = await self._get_compile_script()\n                    run_script = await self._get_run_script()\n\n                    # Use temporary directory to store scripts\n                    with tempfile.TemporaryDirectory(prefix=\"morph_setup_\") as temp_dir:\n                        # Create paths for script files\n                        compile_path = os.path.join(temp_dir, \"compile.sh\")\n                        run_path = os.path.join(temp_dir, \"run.sh\")\n\n                        # Write scripts to temp files\n                        with open(compile_path, \"w\") as f:\n                            f.write(compile_script)\n\n                        with open(run_path, \"w\") as f:\n                            f.write(run_script)\n\n                        async with temp_instance:\n                            # Install dependencies\n                            await temp_instance.aexec(\"apt-get update && apt-get install -y build-essential cmake g++\")\n\n                            # Create workspace directory\n                            await temp_instance.aexec(\n                                \"mkdir -p /workspace && mkdir -p /workspace/graders && chmod 777 /workspace\"\n                            )\n\n                            # Upload scripts to instance\n                            await temp_instance.aupload(compile_path, \"/workspace/compile\")\n                            await temp_instance.aupload(run_path, \"/workspace/run\")\n\n                            # Make scripts executable\n                            await temp_instance.aexec(\"chmod +x /workspace/compile /workspace/run\")\n\n                            # Create snapshot from the prepared instance\n                            final_snapshot = await temp_instance.asnapshot(digest=\"ioi-evaluation-morph\")\n\n                except Exception as e:\n       
             # Ensure instance is stopped if anything fails\n                    await temp_instance.astop()\n                    raise e\n            else:\n                final_snapshot = base_snapshots[0]\n\n            return final_snapshot\n\n    async def _get_compile_script(self):\n        \"\"\"Get the compile script content.\"\"\"\n        return \"\"\"#!/bin/bash\n\nmanager_files=()  # Array to store manager filenames\ncurrent_dir=\"$(pwd)\"\n\n# Checker compilation path\nchecker_dir=\"$current_dir/checker\"\nchecker_src=\"$checker_dir/checker.cpp\"\n\nif [ -e \"$checker_src\" ]; then\n    echo \"Compiling checker\"\n    checker_exe=\"$checker_dir/checker\"\n    g++ -x c++ -std=gnu++17 -O2 -o \"$checker_exe\" \"$checker_src\"\n    chmod +x \"$checker_exe\"\n    if [ $? -ne 0 ]; then\n        echo \"Could not compile checker\" >&2\n        exit 1\n    fi\n    echo \"Compiled checker\"\nelse\n    echo \"No checker found at $checker_src\"\nfi\n\n# Graders path\ngraders_dir=\"$current_dir/graders\"\nif [ ! -e \"$graders_dir\" ]; then\n    echo \"Grader folder was not found\" >&2\n    exit 1\nfi\n\n# Find and compile manager if it exists\nmanager_src=\"$graders_dir/manager.cpp\"\nif [ -e \"$manager_src\" ]; then\n    echo \"Compiling manager\"\n    manager_exe=\"$graders_dir/manager\"\n    g++ -x c++ -std=gnu++17 -O2 -o \"$manager_exe\" \"$manager_src\"\n    chmod +x \"$manager_exe\"\n    if [ $? 
-ne 0 ]; then\n        echo \"Could not compile manager\" >&2\n        exit 1\n    fi\n    manager_files+=(\"manager\")\nfi\n\n# Process other graders\ngraders_list=($(ls \"$graders_dir\" | grep -v 'manager.cpp'))\nfor grader_name in \"${graders_list[@]}\"; do\n    manager_files+=(\"$grader_name\")\ndone\n\n# Extract problem name and compile necessary files\nproblem_name='?'\nfor file in \"${manager_files[@]}\"; do\n    if [[ \"$file\" == *.h && \"$file\" != \"testlib.h\" ]]; then\n        problem_name=\"${file%.h}\"\n        echo \"Problem name: $problem_name\"\n        break\n    fi\ndone\n\nfiles_to_compile=(\"graders/$problem_name.cpp\")\n[ -e graders/grader.cpp ] && files_to_compile+=(\"graders/grader.cpp\")\n[ -e graders/stub.cpp ] && files_to_compile+=(\"graders/stub.cpp\")\n\ng++ -DEVAL -std=gnu++17 -O2 -pipe -s -o graders/\"$problem_name\" \"${files_to_compile[@]}\"\nif [ $? -ne 0 ]; then\n    echo \"Failed to compile $problem_name\" >&2\n    exit 1\nfi\nchmod +x graders/\"$problem_name\"\necho \"Compiled $problem_name from ${files_to_compile[@]} successfully\"\n\necho \"Manager files: ${manager_files[@]}\"\n\"\"\"\n\n    async def _get_run_script(self):\n        \"\"\"Get the run script content.\"\"\"\n        return \"\"\"#!/usr/bin/env bash\n# disable stack limit so you don't get RE with recursion\nulimit -s unlimited\n# some problems have 10MB+ input/output files in their test cases and you might get RE. uncomment if needed\n# ulimit -f 2097152\n\n# Check if grader_config.json exists\nif [ ! -f \"graders/grader_config.json\" ]; then\n    echo \"Error: graders/grader_config.json not found\" >&2\n    echo \"Current directory contents:\" >&2\n    find . 
-type f -o -type d | sed -e 's/[^-][^\\/]*\\//  |/g' -e 's/|\\([^ ]\\)/|-\\1/' >&2\n    exit 1\nfi\n\n# Read task type, code, and time limit from grader_config.json using grep and sed\nTASK_TYPE=$(grep -o '\"task_type\":[^,}]*' graders/grader_config.json | sed 's/\"task_type\":\\\\s*\"\\\\([^\"]*\\\\)\"/\\\\1/')\nTASK_NAME=$(grep -o '\"code\":[^,}]*' graders/grader_config.json | sed 's/\"code\":\\\\s*\"\\\\([^\"]*\\\\)\"/\\\\1/')\nTIME_LIMIT=$(grep -o '\"time_limit\":[^,}]*' graders/grader_config.json | sed 's/\"time_limit\":\\\\s*\\\\([^,}]*\\\\)/\\\\1/')\nMEMORY_LIMIT=$(grep -o '\"memory_limit\":[^,}]*' graders/grader_config.json | sed 's/\"memory_limit\":\\\\s*\\\\([^,}]*\\\\)/\\\\1/')\nTASK_EXECUTABLE=\"graders/$TASK_NAME\"\n\n# Set memory limit in KB (convert from bytes)\nMEMORY_LIMIT_KB=0\nif [ -n \"$MEMORY_LIMIT\" ]; then\n    MEMORY_LIMIT_KB=$(($MEMORY_LIMIT / 1024))\n    # Set the memory limit for the entire script and all child processes\n    ulimit -v $MEMORY_LIMIT_KB\nfi\n\n# \"Securely\" handle the correct output file\nCORRECT_OUTPUT=\"\"\nif [ -f \"correct_output.txt\" ]; then\n    # Read the content and immediately remove the file\n    CORRECT_OUTPUT=$(cat correct_output.txt)\n    rm -f correct_output.txt\nfi\n\n# Create a temporary file for solution output\nSOLUTION_OUTPUT=$(mktemp)\n\n# Global variables for process tracking\ndeclare -a ALL_PIDS\ndeclare -a FIFO_DIRS\n\n# Define cleanup function - simplified assuming timeout exists\nfunction cleanup {\n    # Kill all tracked processes silently\n    exec 2>/dev/null\n    for pid in \"${ALL_PIDS[@]:-}\"; do\n        kill -9 \"$pid\" 2>/dev/null || true\n    done\n\n    # Clean up FIFO directories\n    for dir in \"${FIFO_DIRS[@]:-}\"; do\n        [ -d \"$dir\" ] && rm -rf \"$dir\"\n    done\n\n    # Clean up temporary files\n    rm -f \"$SOLUTION_OUTPUT\" || true\n    exec 2>&2\n}\n\n# Set up signal handling\ntrap cleanup EXIT INT TERM\n\n# Function to handle exit codes consistently across task 
types\nfunction handle_exit_code {\n    local exit_code=$1\n\n    # Check for known timeout exit codes:\n    # - 124: standard timeout exit code\n    # - 137: SIGKILL (128+9), used for hard timeouts\n    # - 143: SIGTERM (128+15), can also be used for timeouts\n    if [ $exit_code -eq 124 ] || [ $exit_code -eq 137 ] || [ $exit_code -eq 143 ]; then\n        echo \"0\"\n        echo \"Time limit exceeded (${TIME_LIMIT}s)\" >&2\n        return 124\n    # All other non-zero exit codes should be treated as runtime errors\n    elif [ $exit_code -ne 0 ]; then\n        echo \"0\"\n        echo \"Runtime error with exit code $exit_code\" >&2\n        return $exit_code\n    fi\n\n    # Success case - return 0\n    return 0\n}\n\n# Function to run a command with timeout (simplified assuming timeout exists)\nfunction run_with_timeout {\n    local soft_limit=$1; shift\n    local command_to_run=\"$@\"\n\n    timeout --preserve-status \"$soft_limit\" \"$@\"\n    return $?\n}\n\ncase \"$TASK_TYPE\" in\n    \"Batch\")\n        # Simple batch execution with timeout\n        run_with_timeout \"$TIME_LIMIT\" ./$TASK_EXECUTABLE < input.txt > \"$SOLUTION_OUTPUT\"\n        exit_code=$?\n\n        # Handle non-zero exit codes\n        handle_exit_code $exit_code\n        if [ $? 
-ne 0 ]; then\n            exit $?\n        fi\n\n        # Check the output if we have a correct output\n        if [ -n \"$CORRECT_OUTPUT\" ]; then\n            # Restore the correct output file\n            echo \"$CORRECT_OUTPUT\" > correct_output.txt\n\n            # Check if there's a custom checker\n            if [ -f \"checker/checker\" ]; then\n                # Let the checker handle everything\n                ./checker/checker input.txt correct_output.txt \"$SOLUTION_OUTPUT\"\n                exit $?\n            else\n                # Simple diff-based checking\n                if diff -bq <(echo \"$CORRECT_OUTPUT\") \"$SOLUTION_OUTPUT\" >/dev/null; then\n                    echo \"1\"\n                    echo \"Output is correct (diff)\" >&2\n                else\n                    echo \"0\"\n                    echo \"Output isn't correct (diff)\" >&2\n                    exit 0\n                fi\n            fi\n        else\n            # If no correct output was provided, just output the solution's output\n            cat \"$SOLUTION_OUTPUT\"\n        fi\n        ;;\n\n    \"Communication\")\n        # Read Communication-specific parameters\n        NUM_PROCESSES=$(grep -o '\"task_type_parameters_Communication_num_processes\":[^,}]*' graders/grader_config.json | sed 's/.*:\\\\s*\\\\([0-9]*\\\\)/\\\\1/' || true)\n        if [ -z \"$NUM_PROCESSES\" ]; then\n            NUM_PROCESSES=1\n        fi\n        USER_IO=$(grep -o '\"task_type_parameters_Communication_user_io\":[^,}]*' graders/grader_config.json | sed 's/.*:\\\\s*\"\\\\([^\"]*\\\\)\"/\\\\1/' || echo \"std_io\")\n\n        # Read custom manager arguments if they exist\n        MANAGER_CUSTOM_ARGS=\"\"\n        if grep -q '\"task_type_parameters_Communication_manager_args\"' graders/grader_config.json; then\n            MANAGER_CUSTOM_ARGS=$(grep -o '\"task_type_parameters_Communication_manager_args\":[^,}]*' graders/grader_config.json | sed 's/.*:\\\\s*\"\\\\([^\"]*\\\\)\"/\\\\1/')\n 
       fi\n\n        # Create temporary directories for FIFOs\n        for i in $(seq 0 $((NUM_PROCESSES-1))); do\n            FIFO_DIRS[$i]=$(mktemp -d)\n\n            # Create FIFOs for this process\n            mkfifo \"${FIFO_DIRS[$i]}/u${i}_to_m\"\n            mkfifo \"${FIFO_DIRS[$i]}/m_to_u${i}\"\n            chmod 755 \"${FIFO_DIRS[$i]}\"\n            chmod 666 \"${FIFO_DIRS[$i]}/u${i}_to_m\" \"${FIFO_DIRS[$i]}/m_to_u${i}\"\n        done\n\n        # Prepare manager arguments\n        MANAGER_ARGS=\"\"\n        for i in $(seq 0 $((NUM_PROCESSES-1))); do\n            MANAGER_ARGS=\"$MANAGER_ARGS ${FIFO_DIRS[$i]}/u${i}_to_m ${FIFO_DIRS[$i]}/m_to_u${i}\"\n        done\n\n        # Add custom manager arguments if specified\n        if [ -n \"$MANAGER_CUSTOM_ARGS\" ]; then\n            MANAGER_ARGS=\"$MANAGER_ARGS $MANAGER_CUSTOM_ARGS\"\n        fi\n\n        # Start all user processes first\n        for i in $(seq 0 $((NUM_PROCESSES-1))); do\n            if [ \"$USER_IO\" = \"fifo_io\" ]; then\n                # Pass FIFOs as arguments\n                ARGS=\"${FIFO_DIRS[$i]}/m_to_u${i} ${FIFO_DIRS[$i]}/u${i}_to_m\"\n                if [ \"$NUM_PROCESSES\" -ne 1 ]; then\n                    ARGS=\"$ARGS $i\"\n                fi\n                ./$TASK_EXECUTABLE $ARGS &\n                ALL_PIDS+=($!)\n            else\n                # Use stdin/stdout redirection\n                if [ \"$NUM_PROCESSES\" -ne 1 ]; then\n                    ./$TASK_EXECUTABLE \"$i\" < \"${FIFO_DIRS[$i]}/m_to_u${i}\" > \"${FIFO_DIRS[$i]}/u${i}_to_m\" 2>/dev/null &\n                    ALL_PIDS+=($!)\n                else\n                    ./$TASK_EXECUTABLE < \"${FIFO_DIRS[$i]}/m_to_u${i}\" > \"${FIFO_DIRS[$i]}/u${i}_to_m\" 2>/dev/null &\n                    ALL_PIDS+=($!)\n                fi\n            fi\n        done\n\n        # Run the manager with timeout using direct pipe from input.txt\n        run_with_timeout \"$TIME_LIMIT\" ./graders/manager $MANAGER_ARGS < 
input.txt > \"$SOLUTION_OUTPUT\"\n\n        exit_code=$?\n\n        # Handle non-zero exit codes\n        handle_exit_code $exit_code\n        if [ $? -ne 0 ]; then\n            exit $?\n        fi\n\n        # Check the output if we have a correct output AND there's a checker (otherwise we assume the manager handles everything)\n        if [ -n \"$CORRECT_OUTPUT\" ] && [ -f \"checker/checker\" ]; then\n            # Restore the correct output file\n            echo \"$CORRECT_OUTPUT\" > correct_output.txt\n\n            # Let the checker handle it\n            ./checker/checker input.txt correct_output.txt \"$SOLUTION_OUTPUT\"\n            exit $?\n        else\n            # we assume the manager handles it\n            cat \"$SOLUTION_OUTPUT\"\n        fi\n        ;;\n\n    *)\n        echo \"0\"\n        echo \"Unsupported task type \\\"$TASK_TYPE\\\"\" >&2\n        exit 1\n        ;;\nesac\n\"\"\"\n\n\ndef get_morph_client_from_env(session=None) -> MorphCloudExecutionClient:\n    \"\"\"\n    Creates a MorphCloudExecutionClient instance using environment variables.\n\n    Environment variables:\n        MORPH_API_KEY: API key for MorphCloud\n\n    Args:\n        session: Optional aiohttp.ClientSession to use for HTTP requests\n\n    Returns:\n        MorphCloudExecutionClient: A configured MorphCloud execution client\n    \"\"\"\n    if not is_morph_available():\n        raise ImportError(\n            \"MorphCloud is not available and required for this function. Please install MorphCloud with \"\n            \"`pip install morphcloud` and add an API key to a `.env` file.\"\n        )\n\n    load_dotenv()\n    api_key = os.environ.get(\"MORPH_API_KEY\")\n    if not api_key:\n        raise ValueError(\"MORPH_API_KEY environment variable is required\")\n\n    return MorphCloudExecutionClient(api_key=api_key)\n\n\n# noqa: W293\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/piston_client.py",
    "content": "import asyncio\nimport os\nimport random\nimport re\nimport subprocess\nfrom collections import Counter\nfrom functools import lru_cache\n\nimport aiohttp\n\n\nclass PistonError(Exception):\n    pass\n\n\n@lru_cache(maxsize=1)\ndef get_piston_client_from_env(session=None):\n    piston_endpoints = os.getenv(\"PISTON_ENDPOINTS\")\n    if piston_endpoints is None:\n        raise ValueError(\n            \"For IOI/CF problems Piston endpoints running our IOI package are required. Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS variable in a `.env` file.\"\n        )\n    piston_endpoints = sorted(\n        piston_endpoints.split(\",\") if piston_endpoints != \"slurm\" else get_slurm_piston_endpoints()\n    )\n    gpu_nb = int(os.getenv(\"LOCAL_RANK\", 0))  # per‑GPU index\n    world = int(os.getenv(\"WORLD_SIZE\", 1))  # total GPUs\n    if world > 1:\n        print(f\"Using a subset of piston endpoints for GPU#{gpu_nb}\")\n        piston_endpoints = piston_endpoints[gpu_nb::world]\n    random.shuffle(piston_endpoints)\n    max_requests_per_endpoint = os.getenv(\"PISTON_MAX_REQUESTS_PER_ENDPOINT\", \"1\")\n    return PistonClient(piston_endpoints, session, max_requests_per_endpoint=int(max_requests_per_endpoint))\n\n\nclass PistonClient:\n    \"\"\"\n    A client that will automatically load balance across multiple Piston (https://github.com/engineer-man/piston) workers.\n    This assumes piston is running our custom cms_ioi package: https://github.com/guipenedo/piston/releases/\n    We recommend starting the instances with the following script as otherwise some IOI problems will hit default limits:\n    ```\n    export PISTON_COMPILE_TIMEOUT=60000\n    export PISTON_RUN_TIMEOUT=60000\n    export PISTON_OUTPUT_MAX_SIZE=1000000000\n    export PISTON_MAX_FILE_SIZE=1000000000\n    export PISTON_DISABLE_NETWORKING=true\n    export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index\n    mkdir /piston\n\n    sed 
-i '/app.use(body_parser.urlencoded/c\\    app.use(body_parser.urlencoded({ extended: true, limit: \\\"512mb\\\" }));' src/index.js\n    sed -i '/app.use(body_parser.json/c\\    app.use(body_parser.json({ limit: \\\"512mb\\\" }));' src/index.js\n\n    # Start server in background\n    node src```\n\n    Piston docs for API usage: https://piston.readthedocs.io/en/latest/api-v2/\n    \"\"\"\n\n    def __init__(\n        self,\n        base_endpoint: str | list[str] = \"http://ip-10-53-80-65:3223/api/v2\",\n        session=None,\n        max_requests_per_endpoint=1,\n    ):\n        self.max_requests_per_endpoint = max_requests_per_endpoint\n        self.base_endpoints = [base_endpoint] if isinstance(base_endpoint, str) else base_endpoint\n        if len(self.base_endpoints) == 0:\n            raise ValueError(\"No Piston endpoints provided. Please check your PISTON_ENDPOINTS environment variable.\")\n        self.endpoint_ids = {endpoint: i for i, endpoint in enumerate(self.base_endpoints)}\n\n        self._session = session\n        self.endpoint_tokens = asyncio.Queue(maxsize=max_requests_per_endpoint * len(self.base_endpoints))\n\n        for _ in range(max_requests_per_endpoint):\n            for base_endpoint in self.base_endpoints:\n                self.endpoint_tokens.put_nowait(base_endpoint)\n        self._endpoint_failures = Counter()\n        self._unhealthy_endpoints = set()\n        self._endpoint_failures_lock = asyncio.Lock()\n\n    @property\n    def session(self):\n        if self._session is None:\n            self._session = aiohttp.ClientSession(\n                timeout=aiohttp.ClientTimeout(sock_read=30),\n                connector=aiohttp.TCPConnector(\n                    limit=self.max_requests_per_endpoint * len(self.base_endpoints),\n                    ttl_dns_cache=300,\n                    keepalive_timeout=5 * 60,\n                ),\n            )\n        return self._session\n\n    async def _wait_for_endpoint(self):\n        
endpoint = await self.endpoint_tokens.get()\n        return endpoint\n\n    async def _release_endpoint(self, endpoint):\n        await self.endpoint_tokens.put(endpoint)\n\n    async def _send_request(self, endpoint, route, data=None, method=\"post\"):\n        async with self.session.request(\n            method, f\"{endpoint.rstrip('/')}/{route}\", json=data, headers={\"Content-Type\": \"application/json\"}\n        ) as response:\n            return await response.json(content_type=None)\n\n    async def _send_to_all(self, route, data=None, method=\"post\"):\n        return await asyncio.gather(\n            *[self._send_request(endpoint, route, data, method) for endpoint in self.base_endpoints]\n        )\n\n    async def _send_to_one(self, endpoint, route, data=None, method=\"post\"):\n        return await self._send_request(endpoint, route, data, method)\n\n    async def install_package(self, language, version):\n        return await self._send_to_all(\"packages\", {\"language\": language, \"version\": version}, method=\"post\")\n\n    async def uninstall_package(self, language, version):\n        return await self._send_to_all(\"packages\", {\"language\": language, \"version\": version}, method=\"delete\")\n\n    async def get_supported_runtimes(self):\n        return await self._send_to_all(\"runtimes\", method=\"get\")\n\n    async def _check_failed_endpoint(self, endpoint):\n        async with self._endpoint_failures_lock:\n            if endpoint in self._unhealthy_endpoints:\n                return\n            try:\n                await asyncio.sleep(5)\n                await self.get_supported_runtimes()\n            except Exception as e:\n                print(f\"Error checking endpoint {endpoint}, dropping it ({e})\")\n                self._unhealthy_endpoints.add(endpoint)\n                if len(self._unhealthy_endpoints) >= len(self.base_endpoints):\n                    raise PistonError(\"All endpoints are unhealthy. 
Please check your Piston workers.\")\n\n    async def send_execute(self, data, language=\"cms_ioi\", max_retries=5):\n        data = data | {\n            \"language\": language,\n            \"version\": \"*\",\n        }\n\n        base_delay = 1.0\n\n        status = None\n        endpoint = None\n\n        for attempt in range(max_retries + 1):\n            try:\n                endpoint = await self._wait_for_endpoint()\n                if attempt > 0:\n                    await asyncio.sleep(1)\n                async with self.session.post(\n                    f\"{endpoint.rstrip('/')}/execute\", json=data, headers={\"Content-Type\": \"application/json\"}\n                ) as response:\n                    status = response.status\n                    res_json = await response.json(content_type=None)\n\n                    if status != 200:\n                        raise PistonError(f\"Server error. status={status}. {res_json}\")\n                    if res_json is None:\n                        raise PistonError(f\"Empty response. 
status={status}\")\n                    # piston overloaded\n                    if \"run\" in res_json and \"Resource temporarily unavailable\" in res_json[\"run\"].get(\"stderr\", \"\"):\n                        raise PistonError(f\"Piston overloaded: {res_json['run']['stderr']}\")\n                    return res_json\n\n            except (PistonError, asyncio.TimeoutError, aiohttp.ClientConnectionError, RuntimeError) as e:\n                # Only retry if we haven't reached max retries yet\n                if attempt < max_retries:\n                    # Calculate backoff with jitter\n                    delay = min(base_delay * (2**attempt), 10)  # Exponential backoff, capped at 10 seconds\n                    jitter = delay * 0.2 * (2 * asyncio.get_event_loop().time() % 1 - 0.5)  # Add ±10% jitter\n                    retry_delay = delay + jitter\n                    print(f\"Retrying in {retry_delay:.2f} seconds [{self.endpoint_ids[endpoint]}] {endpoint} - {e}\")\n\n                    # special case: worker died\n                    if isinstance(e, aiohttp.ClientConnectionError) and \"Connect call failed\" in str(e):\n                        await self._check_failed_endpoint(endpoint)\n                    else:\n                        # hopefully we won't get this one again\n                        await self._release_endpoint(endpoint)\n                    endpoint = None\n\n                    await asyncio.sleep(retry_delay)\n                else:\n                    await self._check_failed_endpoint(endpoint)\n            except Exception as e:\n                print(f\"Propagating exception {type(e)}: {e}\")\n                raise e\n            finally:\n                # Ensure endpoint is always released, even if an exception occurs\n                if endpoint is not None:\n                    try:\n                        await self._release_endpoint(endpoint)\n                    except Exception as e:\n                        print(f\"Error 
releasing endpoint {endpoint}: {e}\")\n                    endpoint = None\n\n\ndef get_slurm_piston_endpoints():\n    \"\"\"Get list of active piston worker endpoints from squeue output\"\"\"\n    # Run squeue command to get job name, hostname and status, filtering for RUNNING state\n    result = subprocess.run(\n        [\"squeue\", '--format=\"%j %N %T\"', \"--noheader\", \"--states=RUNNING\"], capture_output=True, text=True\n    )\n\n    # Split output into lines and skip header\n    lines = result.stdout.strip().split(\"\\n\")\n\n    endpoints = []\n    for line in lines:\n        # Parse job name from squeue output\n        fields = line.split()\n        job_name = fields[0].strip('\"')  # Remove quotes\n        hostname = fields[1]\n\n        # Extract port if job name matches pattern\n        match = re.match(r\"piston-worker-(\\d+)\", job_name)\n        if match:\n            port = match.group(1)\n            endpoints.append(f\"http://{hostname}:{port}/api/v2\")\n\n    return endpoints\n"
  },
  {
    "path": "src/open_r1/utils/competitive_programming/utils.py",
    "content": "from itertools import islice\n\n\ndef batched(iterable, n):\n    \"Batch data into lists of length n. The last batch may be shorter.\"\n    # batched('ABCDEFG', 3) --> ABC DEF G\n    if n < 1:\n        return iterable\n    it = iter(iterable)\n    while batch := list(islice(it, n)):\n        yield batch\n"
  },
  {
    "path": "src/open_r1/utils/data.py",
    "content": "import logging\n\nimport datasets\nfrom datasets import DatasetDict, concatenate_datasets\n\nfrom ..configs import ScriptArguments\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_dataset(args: ScriptArguments) -> DatasetDict:\n    \"\"\"Load a dataset or a mixture of datasets based on the configuration.\n\n    Args:\n        args (ScriptArguments): Script arguments containing dataset configuration.\n\n    Returns:\n        DatasetDict: The loaded datasets.\n    \"\"\"\n    if args.dataset_name and not args.dataset_mixture:\n        logger.info(f\"Loading dataset: {args.dataset_name}\")\n        return datasets.load_dataset(args.dataset_name, args.dataset_config)\n    elif args.dataset_mixture:\n        logger.info(f\"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets\")\n        seed = args.dataset_mixture.seed\n        datasets_list = []\n\n        for dataset_config in args.dataset_mixture.datasets:\n            logger.info(f\"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})\")\n            ds = datasets.load_dataset(\n                dataset_config.id,\n                dataset_config.config,\n                split=dataset_config.split,\n            )\n            if dataset_config.columns is not None:\n                ds = ds.select_columns(dataset_config.columns)\n            if dataset_config.weight is not None:\n                ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight)))\n                logger.info(\n                    f\"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples\"\n                )\n\n            datasets_list.append(ds)\n\n        if datasets_list:\n            combined_dataset = concatenate_datasets(datasets_list)\n            combined_dataset = combined_dataset.shuffle(seed=seed)\n            logger.info(f\"Created dataset mixture with 
{len(combined_dataset)} examples\")\n\n            if args.dataset_mixture.test_split_size is not None:\n                combined_dataset = combined_dataset.train_test_split(\n                    test_size=args.dataset_mixture.test_split_size, seed=seed\n                )\n                logger.info(\n                    f\"Split dataset into train and test sets with test size: {args.dataset_mixture.test_split_size}\"\n                )\n                return combined_dataset\n            else:\n                return DatasetDict({\"train\": combined_dataset})\n        else:\n            raise ValueError(\"No datasets were loaded from the mixture configuration\")\n\n    else:\n        raise ValueError(\"Either `dataset_name` or `dataset_mixture` must be provided\")\n"
  },
  {
    "path": "src/open_r1/utils/evaluation.py",
    "content": "import subprocess\nfrom typing import TYPE_CHECKING, Dict, Union\n\nfrom .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id\n\n\nif TYPE_CHECKING:\n    from trl import GRPOConfig, SFTConfig, ModelConfig\n\nimport base64\nimport os\n\n\n# We need a special environment setup to launch vLLM from within Slurm training jobs.\n# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105\n# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269\nuser_home_directory = os.path.expanduser(\"~\")\nVLLM_SLURM_PREFIX = [\n    \"env\",\n    \"-i\",\n    \"bash\",\n    \"-c\",\n    f\"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch \",\n]\n\n\ndef register_lighteval_task(\n    configs: Dict[str, str],\n    eval_suite: str,\n    task_name: str,\n    task_list: str,\n    num_fewshot: int = 0,\n):\n    \"\"\"Registers a LightEval task configuration.\n\n    - Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl\n    - Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks\n\n    Args:\n        configs (Dict[str, str]): The dictionary to store the task configuration.\n        eval_suite (str, optional): The evaluation suite.\n        task_name (str): The name of the task.\n        task_list (str): The comma-separated list of tasks in the format \"extended|{task_name}|{num_fewshot}|0\" or \"lighteval|{task_name}|{num_fewshot}|0\".\n        num_fewshot (int, optional): The number of few-shot examples. Defaults to 0.\n        is_custom_task (bool, optional): Whether the task is a custom task. 
Defaults to False.\n    \"\"\"\n    # Format task list in lighteval format\n    task_list = \",\".join(f\"{eval_suite}|{task}|{num_fewshot}|0\" for task in task_list.split(\",\"))\n    configs[task_name] = task_list\n\n\nLIGHTEVAL_TASKS = {}\n\nregister_lighteval_task(LIGHTEVAL_TASKS, \"lighteval\", \"math_500\", \"math_500\", 0)\nregister_lighteval_task(LIGHTEVAL_TASKS, \"lighteval\", \"aime24\", \"aime24\", 0)\nregister_lighteval_task(LIGHTEVAL_TASKS, \"lighteval\", \"aime25\", \"aime25\", 0)\nregister_lighteval_task(LIGHTEVAL_TASKS, \"lighteval\", \"gpqa\", \"gpqa:diamond\", 0)\nregister_lighteval_task(LIGHTEVAL_TASKS, \"extended\", \"lcb\", \"lcb:codegeneration\", 0)\nregister_lighteval_task(LIGHTEVAL_TASKS, \"extended\", \"lcb_v4\", \"lcb:codegeneration_v4\", 0)\n\n\ndef get_lighteval_tasks():\n    return list(LIGHTEVAL_TASKS.keys())\n\n\nSUPPORTED_BENCHMARKS = get_lighteval_tasks()\n\n\ndef run_lighteval_job(\n    benchmark: str,\n    training_args: Union[\"SFTConfig\", \"GRPOConfig\"],\n    model_args: \"ModelConfig\",\n) -> None:\n    task_list = LIGHTEVAL_TASKS[benchmark]\n    model_name = training_args.hub_model_id\n    model_revision = training_args.hub_model_revision\n    # For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM\n    num_gpus = get_gpu_count_for_vllm(model_name, model_revision)\n    if get_param_count_from_repo_id(model_name) >= 30_000_000_000:\n        tensor_parallel = True\n    else:\n        num_gpus = 2  # Hack while cluster is full\n        tensor_parallel = False\n\n    cmd = VLLM_SLURM_PREFIX.copy()\n    cmd_args = [\n        f\"--gres=gpu:{num_gpus}\",\n        f\"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}\",\n        \"slurm/evaluate.slurm\",\n        benchmark,\n        f'\"{task_list}\"',\n        model_name,\n        model_revision,\n        f\"{tensor_parallel}\",\n        f\"{model_args.trust_remote_code}\",\n    ]\n    if 
training_args.system_prompt is not None:\n        # encode to base64 to avoid issues with special characters\n        # we decode in the sbatch script\n        prompt_encoded = base64.b64encode(training_args.system_prompt.encode()).decode()\n        cmd_args.append(prompt_encoded)\n    cmd[-1] += \" \" + \" \".join(cmd_args)\n    subprocess.run(cmd, check=True)\n\n\ndef run_benchmark_jobs(training_args: Union[\"SFTConfig\", \"GRPOConfig\"], model_args: \"ModelConfig\") -> None:\n    benchmarks = training_args.benchmarks\n    if len(benchmarks) == 1 and benchmarks[0] == \"all\":\n        benchmarks = get_lighteval_tasks()\n        # Evaluate on all supported benchmarks. Later we may want to include a `chat` option\n        # that just evaluates on `ifeval` and `mt_bench` etc.\n\n    for benchmark in benchmarks:\n        print(f\"Launching benchmark `{benchmark}`\")\n        if benchmark in get_lighteval_tasks():\n            run_lighteval_job(benchmark, training_args, model_args)\n        else:\n            raise ValueError(f\"Unknown benchmark {benchmark}\")\n"
  },
  {
    "path": "src/open_r1/utils/hub.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport re\nfrom concurrent.futures import Future\n\nfrom transformers import AutoConfig\n\nfrom huggingface_hub import (\n    create_branch,\n    create_repo,\n    get_safetensors_metadata,\n    list_repo_commits,\n    list_repo_files,\n    list_repo_refs,\n    repo_exists,\n    upload_folder,\n)\nfrom trl import GRPOConfig, SFTConfig\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future:\n    \"\"\"Pushes the model to branch on a Hub repo.\"\"\"\n\n    # Create a repo if it doesn't exist yet\n    repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)\n    # Get initial commit to branch from\n    initial_commit = list_repo_commits(training_args.hub_model_id)[-1]\n    # Now create the branch we'll be pushing to\n    create_branch(\n        repo_id=training_args.hub_model_id,\n        branch=training_args.hub_model_revision,\n        revision=initial_commit.commit_id,\n        exist_ok=True,\n    )\n    logger.info(f\"Created target repo at {repo_url}\")\n    logger.info(f\"Pushing to the Hub revision {training_args.hub_model_revision}...\")\n    ignore_patterns = [\"checkpoint-*\", \"*.pth\"]\n    ignore_patterns.extend(extra_ignore_patterns)\n    future 
= upload_folder(\n        repo_id=training_args.hub_model_id,\n        folder_path=training_args.output_dir,\n        revision=training_args.hub_model_revision,\n        commit_message=f\"Add {training_args.hub_model_revision} checkpoint\",\n        ignore_patterns=ignore_patterns,\n        run_as_future=True,\n    )\n    logger.info(f\"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!\")\n\n    return future\n\n\ndef check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):\n    \"\"\"Checks if a given Hub revision exists.\"\"\"\n    if repo_exists(training_args.hub_model_id):\n        if training_args.push_to_hub_revision is True:\n            # First check if the revision exists\n            revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches]\n            # If the revision exists, we next check it has a README file\n            if training_args.hub_model_revision in revisions:\n                repo_files = list_repo_files(\n                    repo_id=training_args.hub_model_id,\n                    revision=training_args.hub_model_revision,\n                )\n                if \"README.md\" in repo_files and training_args.overwrite_hub_revision is False:\n                    raise ValueError(\n                        f\"Revision {training_args.hub_model_revision} already exists. 
\"\n                        \"Use --overwrite_hub_revision to overwrite it.\"\n                    )\n\n\ndef get_param_count_from_repo_id(repo_id: str) -> int:\n    \"\"\"Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID.\"\"\"\n    try:\n        metadata = get_safetensors_metadata(repo_id)\n        return list(metadata.parameter_count.values())[0]\n    except Exception:\n        # Pattern to match products (like 8x7b) and single values (like 42m)\n        pattern = r\"((\\d+(\\.\\d+)?)(x(\\d+(\\.\\d+)?))?)([bm])\"\n        matches = re.findall(pattern, repo_id.lower())\n\n        param_counts = []\n        for full_match, number1, _, _, number2, _, unit in matches:\n            if number2:  # If there's a second number, it's a product\n                number = float(number1) * float(number2)\n            else:  # Otherwise, it's a single value\n                number = float(number1)\n\n            if unit == \"b\":\n                number *= 1_000_000_000  # Convert to billion\n            elif unit == \"m\":\n                number *= 1_000_000  # Convert to million\n\n            param_counts.append(number)\n\n        if len(param_counts) > 0:\n            # Return the largest number\n            return int(max(param_counts))\n        else:\n            # Return -1 if no match found\n            return -1\n\n\ndef get_gpu_count_for_vllm(model_name: str, revision: str = \"main\", num_gpus: int = 8) -> int:\n    \"\"\"vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs.\n    This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model.\n    \"\"\"\n    config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)\n    # Get number of attention heads\n    num_heads = config.num_attention_heads\n    # 
Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus\n    while num_heads % num_gpus != 0 or 64 % num_gpus != 0:\n        logger.info(f\"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus\")\n        num_gpus -= 1\n    return num_gpus\n"
  },
  {
    "path": "src/open_r1/utils/import_utils.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers.utils.import_utils import _is_package_available\n\n\n# Use same as transformers.utils.import_utils\n_e2b_available = _is_package_available(\"e2b\")\n\n\ndef is_e2b_available() -> bool:\n    return _e2b_available\n\n\n_morph_available = _is_package_available(\"morphcloud\")\n\n\ndef is_morph_available() -> bool:\n    return _morph_available\n"
  },
  {
    "path": "src/open_r1/utils/model_utils.py",
    "content": "import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer\n\nfrom trl import ModelConfig, get_kbit_device_map, get_quantization_config\n\nfrom ..configs import GRPOConfig, SFTConfig\n\n\ndef get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:\n    \"\"\"Get the tokenizer for the model.\"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_args.model_name_or_path,\n        revision=model_args.model_revision,\n        trust_remote_code=model_args.trust_remote_code,\n    )\n\n    if training_args.chat_template is not None:\n        tokenizer.chat_template = training_args.chat_template\n\n    return tokenizer\n\n\ndef get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:\n    \"\"\"Get the model\"\"\"\n    torch_dtype = (\n        model_args.torch_dtype if model_args.torch_dtype in [\"auto\", None] else getattr(torch, model_args.torch_dtype)\n    )\n    quantization_config = get_quantization_config(model_args)\n    model_kwargs = dict(\n        revision=model_args.model_revision,\n        trust_remote_code=model_args.trust_remote_code,\n        attn_implementation=model_args.attn_implementation,\n        torch_dtype=torch_dtype,\n        use_cache=False if training_args.gradient_checkpointing else True,\n        device_map=get_kbit_device_map() if quantization_config is not None else None,\n        quantization_config=quantization_config,\n    )\n    model = AutoModelForCausalLM.from_pretrained(\n        model_args.model_name_or_path,\n        **model_kwargs,\n    )\n    return model\n"
  },
  {
    "path": "src/open_r1/utils/routed_morph.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional\n\nimport requests\n\n\nclass RoutedMorphSandbox:\n    \"\"\"\n    Client for the MorphCloud router service that mimics the API of MorphCloud's Sandbox.\n\n    This class provides a simple interface to execute code via a central MorphCloud router,\n    which manages sandbox creation and cleanup. It allows batch processing of multiple scripts\n    in a single request for improved efficiency.\n\n    Attributes:\n        router_url (str): The URL of the MorphCloud router service.\n        timeout (int): Execution timeout in seconds.\n        request_timeout (int): HTTP request timeout in seconds.\n    \"\"\"\n\n    def __init__(self, router_url: str, timeout: int = 300, request_timeout: int = 60):\n        \"\"\"\n        Initialize the routed MorphCloud sandbox client.\n\n        Args:\n            router_url: The URL of the MorphCloud router, including host and port.\n            timeout: Default execution timeout in seconds.\n            request_timeout: Default HTTP request timeout in seconds.\n        \"\"\"\n        self.router_url = router_url\n        self.timeout = timeout\n        self.request_timeout = request_timeout\n\n    def run_code(\n        self,\n        scripts: List[str],\n        languages: Optional[List[str]] = None,\n        timeout: Optional[int] = None,\n        
request_timeout: Optional[int] = None,\n    ) -> List:\n        \"\"\"\n        Execute multiple scripts using MorphCloud via the router.\n\n        Args:\n            scripts: List of code scripts to execute.\n            languages: List of programming languages for each script. If None, defaults to Python for all scripts.\n            timeout: Execution timeout in seconds. If None, uses the instance timeout.\n            request_timeout: HTTP request timeout in seconds. If None, uses the instance request_timeout.\n\n        Returns:\n            List of execution results with text and exception_str properties.\n        \"\"\"\n\n        actual_timeout = timeout if timeout is not None else self.timeout\n        actual_request_timeout = request_timeout if request_timeout is not None else self.request_timeout\n\n        # Default to Python for all scripts if languages is not provided\n        if languages is None:\n            languages = [\"python\"] * len(scripts)\n\n        payload = {\n            \"scripts\": scripts,\n            \"languages\": languages,\n            \"timeout\": actual_timeout,\n            \"request_timeout\": actual_request_timeout,\n        }\n\n        try:\n            endpoint = f\"http://{self.router_url}/execute_batch\"\n            response = requests.post(endpoint, json=payload, timeout=actual_request_timeout)\n\n            if response.status_code != 200:\n                error = f\"Request to MorphCloud router failed with status code: {response.status_code}\"\n                print(error)\n\n                results = []\n                for _ in scripts:\n                    results.append(type(\"obj\", (object,), {\"text\": None, \"exception_str\": error}))\n                return results\n\n            response_data = response.json()\n            results = []\n\n            for item in response_data:\n                # Log the response data to see what we're getting\n                # print(f\"RoutedMorphSandbox: Got response 
item: {item}\")\n                result = type(\n                    \"obj\",\n                    (object,),\n                    {\n                        \"text\": item.get(\"text\"),\n                        \"exception_str\": item.get(\"exception_str\"),\n                    },\n                )\n                results.append(result)\n\n            return results\n\n        except Exception as e:\n            error = f\"Error communicating with MorphCloud router: {str(e)}\"\n            print(error)\n\n            results = []\n            for _ in scripts:\n                results.append(type(\"obj\", (object,), {\"text\": None, \"exception_str\": error}))\n            return results\n"
  },
  {
    "path": "src/open_r1/utils/routed_sandbox.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional\n\nimport requests\nfrom e2b_code_interpreter.models import Execution, ExecutionError, Result\n\n\nclass RoutedSandbox:\n    \"\"\"\n    A sandbox environment that routes code execution requests to the E2B Router.\n    This class is designed for batched execution of scripts, primarily for Python code.\n    It mimics the usage of 'Sandbox' from 'e2b_code_interpreter', but adds support for batch processing.\n\n    Attributes:\n        router_url (str): The URL of the E2B Router to which code execution requests are sent.\n    \"\"\"\n\n    def __init__(self, router_url: str):\n        \"\"\"\n        Initializes the RoutedSandbox with the specified router URL.\n\n        Args:\n            router_url (str): The URL of the E2B Router.\n        \"\"\"\n        self.router_url = router_url\n\n    def run_code(\n        self,\n        scripts: list[str],\n        languages: Optional[List[str]] = None,\n        timeout: Optional[int] = None,\n        request_timeout: Optional[int] = None,\n    ) -> list[Execution]:\n        \"\"\"\n        Executes a batch of scripts in the sandbox environment.\n\n        Args:\n            scripts (list[str]): A list of code scripts to execute.\n            languages (list[str], optional): List of programming languages for each script. 
If None, defaults to Python for all scripts.\n            timeout (Optional[int], optional): The maximum execution time for each script in seconds. Defaults to 300 seconds.\n            request_timeout (Optional[int], optional): The timeout for the HTTP request in seconds. Defaults to 30 seconds.\n\n        Returns:\n            list[Execution]: A list of Execution objects containing the results, logs, and errors (if any) for each script.\n        \"\"\"\n        # Set default values for timeouts if not provided\n        if timeout is None:\n            timeout = 300  # Default to 5 minutes\n        if request_timeout is None:\n            request_timeout = 30  # Default to 30 seconds\n\n        # Default to Python for all scripts if languages is not provided\n        if languages is None:\n            languages = [\"python\"] * len(scripts)\n\n        # Prepare the payload for the HTTP POST request\n        payload = {\n            \"scripts\": scripts,\n            \"languages\": languages,\n            \"timeout\": timeout,\n            \"request_timeout\": request_timeout,\n        }\n\n        # Send the request to the E2B Router\n        response = requests.post(f\"http://{self.router_url}/execute_batch\", json=payload)\n        if not response.ok:\n            print(f\"Request failed with status code: {response.status_code}\")\n\n        # Parse the response and construct Execution objects\n        results = response.json()\n        output = []\n        for result in results:\n            if result[\"execution\"] is None:\n                # If execution is None, create an empty Execution object\n                # This can happen when a script times out or fails to execute\n                execution = Execution()\n            else:\n                execution = Execution(\n                    results=[Result(**r) for r in result[\"execution\"][\"results\"]],\n                    logs=result[\"execution\"][\"logs\"],\n                    
error=(ExecutionError(**result[\"execution\"][\"error\"]) if result[\"execution\"][\"error\"] else None),\n                    execution_count=result[\"execution\"][\"execution_count\"],\n                )\n            output.append(execution)\n\n        return output\n\n\nif __name__ == \"__main__\":\n    # for local testing launch an E2B router with: python scripts/e2b_router.py\n    sbx = RoutedSandbox(router_url=\"0.0.0.0:8000\")\n    codes = [\"print('hello world')\", \"print('hello world)\"]\n    executions = sbx.run_code(codes)  # Execute Python inside the sandbox\n\n    print(executions)\n"
  },
  {
    "path": "src/open_r1/utils/wandb_logging.py",
    "content": "import os\n\n\ndef init_wandb_training(training_args):\n    \"\"\"\n    Helper function for setting up Weights & Biases logging tools.\n    \"\"\"\n    if training_args.wandb_entity is not None:\n        os.environ[\"WANDB_ENTITY\"] = training_args.wandb_entity\n    if training_args.wandb_project is not None:\n        os.environ[\"WANDB_PROJECT\"] = training_args.wandb_project\n    if training_args.wandb_run_group is not None:\n        os.environ[\"WANDB_RUN_GROUP\"] = training_args.wandb_run_group\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/slow/test_code_reward.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport unittest\n\nfrom datasets import load_dataset\n\nfrom e2b_code_interpreter.models import Execution, ExecutionError\nfrom open_r1.rewards import code_reward, ioi_code_reward\nfrom open_r1.utils.routed_morph import RoutedMorphSandbox\nfrom open_r1.utils.routed_sandbox import RoutedSandbox\n\n\nclass TestCodeRewards(unittest.TestCase):\n    def test_python_code_reward(self):\n        # requires E2B, see the README.md file\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n        NUM_SAMPLES = 20\n        samples = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        test_completions = [[{\"content\": sample[\"gold_standard_solution\"]}] for sample in samples]\n        reward_kwargs = {\"verification_info\": [sample[\"verification_info\"] for sample in samples]}\n        rewards = code_reward(test_completions, **reward_kwargs)\n        print(rewards)\n        assert rewards == [1.0] * NUM_SAMPLES\n\n    def test_e2b_router(self):\n        # run router locally: python scripts/e2b_router.py\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n        NUM_SAMPLES = 128\n        samples = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        test_completions = [[{\"content\": 
sample[\"gold_standard_solution\"]}] for sample in samples]\n        reward_kwargs = {\"verification_info\": [sample[\"verification_info\"] for sample in samples]}\n        rewards = code_reward(test_completions, e2b_router_url=\"0.0.0.0:8000\", **reward_kwargs)\n        print(rewards)\n        assert rewards == [1.0] * NUM_SAMPLES\n\n    def test_e2b_router_parallel(self):\n        # run router locally: python scripts/e2b_router.py\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n\n        BATCH_SIZE = 32\n        NUM_SAMPLES = 256\n\n        def batch_code_reward(examples):\n            test_completions = [[{\"content\": solution}] for solution in examples[\"gold_standard_solution\"]]\n            reward_kwargs = {\n                \"verification_info\": [verification_info for verification_info in examples[\"verification_info\"]]\n            }\n            rewards = code_reward(test_completions, e2b_router_url=\"0.0.0.0:8000\", **reward_kwargs)\n            assert rewards == [1.0] * BATCH_SIZE\n            return examples\n\n        code_dataset = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        code_dataset = code_dataset.map(\n            batch_code_reward,\n            batched=True,\n            batch_size=BATCH_SIZE,\n            num_proc=4,\n            load_from_cache_file=False,\n        )\n\n    def test_ioi_code_reward(self):\n        # This slow test case requires spinning up a bunch (I tested with ~64) of piston workers, see docs here\n        # slurm/piston/README.md\n        code_dataset = load_dataset(\"open-r1/ioi-reward-test-dataset\")\n        NUM_SAMPLES = 16\n        samples = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        test_completions = [[{\"content\": f\"```cpp\\n{sample['sample_solution']}```\"}] for sample in samples]\n        keys = [key for key in samples[0] if key not in [\"prompt\", \"completion\"]]\n        reward_kwargs = {key: 
[example[key] for example in samples] for key in keys}\n        rewards = ioi_code_reward(test_completions, **reward_kwargs)\n        print(rewards)\n        assert rewards == [1.0] * NUM_SAMPLES\n\n    def test_e2b_router_run_code_success(self):\n        # run router locally: python scripts/e2b_router.py\n        routed_sandbox = RoutedSandbox(router_url=\"localhost:8000\")\n        scripts = [\n            \"print('hello from integration test')\",\n            \"result = 2 + 2\\nprint(result)\",\n        ]\n\n        results = routed_sandbox.run_code(scripts)\n\n        assert len(results) == 2\n\n        for result in results:\n            assert isinstance(result, Execution)\n            # assert result.exit_code == 0\n            assert result.error is None\n            assert \"hello\" in result.logs[\"stdout\"][0] or \"4\" in result.logs[\"stdout\"][0]\n\n    def test_e2b_router_run_code_with_error(self):\n        # run router locally: python scripts/e2b_router.py\n\n        routed_sandbox = RoutedSandbox(router_url=\"localhost:8000\")\n        scripts = [\"print('this is fine')\", \"print('unterminated string\"]\n\n        results = routed_sandbox.run_code(scripts)\n\n        assert len(results) == 2\n\n        # First one should be okay\n        # assert results[0].exit_code == 0 # Execution object has no attribute 'exit_code'\n        assert results[0].error is None\n        assert \"this is fine\" in results[0].logs[\"stdout\"][0]\n\n        # Second one should have a syntax error\n\n        # assert results[1].exit_code != 0 # Execution object has no attribute 'exit_code'\n        assert results[1].error is not None\n        assert isinstance(results[1].error, ExecutionError)\n        assert \"SyntaxError\" in results[1].error.name\n\n    def test_python_code_reward_morph(self):\n        # requires MorphCloud, see the README.md file\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n       
 NUM_SAMPLES = 20\n        samples = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        test_completions = [[{\"content\": sample[\"gold_standard_solution\"]}] for sample in samples]\n        reward_kwargs = {\n            \"verification_info\": [sample[\"verification_info\"] for sample in samples],\n            \"provider_type\": \"morph\",\n        }\n        rewards = code_reward(test_completions, **reward_kwargs)\n        print(rewards)\n        assert rewards == [1.0] * NUM_SAMPLES\n\n    def test_morph_router(self):\n        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n        NUM_SAMPLES = 32\n        samples = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        test_completions = [[{\"content\": sample[\"gold_standard_solution\"]}] for sample in samples]\n        reward_kwargs = {\n            \"verification_info\": [sample[\"verification_info\"] for sample in samples],\n            \"provider_type\": \"morph\",\n            \"morph_router_url\": \"0.0.0.0:8001\",\n        }\n        rewards = code_reward(test_completions, **reward_kwargs)\n        print(rewards)\n        assert rewards == [1.0] * NUM_SAMPLES\n\n    def test_morph_router_parallel(self):\n        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20\n        code_dataset = load_dataset(\"open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled\")\n\n        BATCH_SIZE = 32\n        NUM_SAMPLES = 256\n\n        def batch_code_reward(examples):\n            test_completions = [[{\"content\": solution}] for solution in examples[\"gold_standard_solution\"]]\n            reward_kwargs = {\n                \"verification_info\": [verification_info for verification_info in examples[\"verification_info\"]],\n                \"provider_type\": \"morph\",\n                
\"morph_router_url\": \"0.0.0.0:8001\",\n            }\n            rewards = code_reward(test_completions, **reward_kwargs)\n            assert rewards == [1.0] * BATCH_SIZE\n            return examples\n\n        code_dataset = code_dataset[\"train\"].select(range(NUM_SAMPLES))\n        code_dataset = code_dataset.map(\n            batch_code_reward,\n            batched=True,\n            batch_size=BATCH_SIZE,\n            num_proc=4,\n            load_from_cache_file=False,\n        )\n\n    def test_morph_router_run_code_success(self):\n        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20\n\n        routed_sandbox = RoutedMorphSandbox(router_url=\"localhost:8001\")\n        scripts = [\n            \"print('hello from morph integration test')\",\n            \"result = 2 + 2\\nprint(result)\",\n        ]\n\n        results = routed_sandbox.run_code(scripts)\n\n        assert len(results) == 2\n\n        for result in results:\n            assert result.exception_str is None\n            assert \"hello\" in result.text or \"4\" in result.text\n\n    def test_morph_router_run_code_with_error(self):\n        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20\n\n        routed_sandbox = RoutedMorphSandbox(router_url=\"localhost:8001\")\n        scripts = [\"print('this is fine with morph')\", \"print('unterminated string\"]\n\n        results = routed_sandbox.run_code(scripts)\n\n        assert len(results) == 2\n\n        # First one should be okay\n        assert results[0].exception_str is None\n        assert \"this is fine with morph\" in results[0].text\n\n        # Second one should have a syntax error\n        assert \"SyntaxError\" in results[1].text\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_rewards.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport unittest\n\nfrom dotenv import load_dotenv\nfrom open_r1.configs import GRPOScriptArguments\nfrom open_r1.rewards import (\n    accuracy_reward,\n    format_reward,\n    get_code_format_reward,\n    get_cosine_scaled_reward,\n    get_repetition_penalty_reward,\n    get_reward_funcs,\n    get_soft_overlong_punishment,\n    len_reward,\n    reasoning_steps_reward,\n    tag_count_reward,\n)\n\n\nload_dotenv()\n\n\nclass TestGetRewardFuncs(unittest.TestCase):\n    def test_get_reward_funcs(self):\n        \"\"\"Test get_reward_funcs with various reward functions.\"\"\"\n        reward_names = [\n            \"accuracy\",\n            \"format\",\n            \"reasoning_steps\",\n            \"cosine\",\n            \"repetition_penalty\",\n            \"length\",\n            \"tag_count\",\n            \"code\",\n            \"ioi_code\",\n            \"code_format\",\n            \"binary_code\",\n        ]\n        reward_func_names = [\n            \"accuracy_reward\",\n            \"format_reward\",\n            \"reasoning_steps_reward\",\n            \"cosine_scaled_reward\",\n            \"repetition_penalty_reward\",\n            \"len_reward\",\n            \"tag_count_reward\",\n            \"code_reward\",\n            \"ioi_code_reward\",\n            \"code_format_reward\",\n            \"binary_code_reward\",\n      
  ]\n\n        args = GRPOScriptArguments(\n            dataset_name=\"dummy\",\n            reward_funcs=reward_names,\n        )\n\n        reward_funcs = get_reward_funcs(args)\n        self.assertEqual(len(reward_funcs), 11)\n        for func_name, func in zip(reward_func_names, reward_funcs):\n            self.assertEqual(func_name, func.__name__)\n\n\nclass TestRewards(unittest.TestCase):\n    def test_accuracy_reward_correct_answer(self):\n        \"\"\"Test accuracy_reward with a correct answer.\"\"\"\n        completion = [[{\"content\": r\"\\boxed{\\frac{63}{400}}\"}]]\n        solution = [r\"\\frac{63}{400}\"]\n        rewards = accuracy_reward(completion, solution)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_accuracy_reward_wrong_answer(self):\n        \"\"\"Test accuracy_reward with an incorrect answer.\"\"\"\n        completion = [[{\"content\": r\"\\boxed{\\frac{64}{400}}\"}]]\n        solution = [r\"\\frac{63}{400}\"]\n        rewards = accuracy_reward(completion, solution)\n        self.assertEqual(rewards[0], 0.0)\n\n    def test_accuracy_reward_wrong_answer_no_latex(self):\n        \"\"\"Test accuracy_reward with an incorrect answer and gold solution with no latex.\"\"\"\n        completion = [[{\"content\": r\"\\boxed{3}\"}]]\n        solution = [\"6\"]\n        rewards = accuracy_reward(completion, solution)\n        self.assertEqual(rewards[0], 0.0)\n\n    def test_format_reward_correct(self):\n        \"\"\"Test format_reward with correct format.\"\"\"\n        completion = [[{\"content\": \"<think>\\nSome reasoning\\n</think>\\n<answer>\\nThe answer\\n</answer>\"}]]\n        rewards = format_reward(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_format_reward_incorrect(self):\n        \"\"\"Test format_reward with incorrect format.\"\"\"\n        incorrect_formats = [\n            \"<think>Only thinking</think>\",\n            \"<answer>Only answer</answer>\",\n            \"No tags at all\",\n           
 \"<think>Missing closing</think><answer>Missing closing\",\n            \"<think>Wrong order</answer><answer>Wrong order</think>\",\n        ]\n\n        for fmt in incorrect_formats:\n            completion = [[{\"content\": fmt}]]\n            rewards = format_reward(completion)\n            self.assertEqual(rewards[0], 0.0)\n\n    def test_reasoning_steps_reward(self):\n        \"\"\"Test reasoning_steps_reward with various formats.\"\"\"\n        test_cases = [\n            # Full credit cases (3 or more steps)\n            (\"Step 1: First step\\nStep 2: Second step\\nStep 3: Third step\", 1.0),\n            (\"First, we do this.\\nSecond, we do that.\\nFinally, we conclude.\", 1.0),\n            # Partial credit cases (less than 3 steps)\n            (\"Step 1: Only step\", 1 / 3),\n            (\"First, we do this.\\nFinally, we conclude.\", 2 / 3),\n            # No credit case\n            (\"Just plain text without any clear steps\", 0.0),\n        ]\n\n        for content, expected_reward in test_cases:\n            completion = [[{\"content\": content}]]\n            rewards = reasoning_steps_reward(completion)\n            self.assertAlmostEqual(rewards[0], expected_reward)\n\n    def test_multiple_completions(self):\n        \"\"\"Test handling multiple completions at once.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}\"}],\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}\"}],\n        ]\n        solutions = [r\"\\frac{63}{400}\", r\"\\frac{63}{400}\"]\n\n        rewards = accuracy_reward(completions, solutions)\n        self.assertEqual(len(rewards), 2)\n        self.assertEqual(rewards[0], 1.0)\n        self.assertEqual(rewards[1], 0.0)\n\n    def test_cosine_scaled_reward(self):\n        \"\"\"Test cosine_scaled_reward with various cases.\"\"\"\n        # Test parameters\n        test_params = {\n            \"min_value_wrong\": -1.0,\n            \"max_value_wrong\": -0.5,\n            
\"min_value_correct\": 0.5,\n            \"max_value_correct\": 1.0,\n            \"max_len\": 100,\n        }\n\n        test_cases = [\n            # Correct answers with different lengths\n            (\n                r\"\\boxed{\\frac{63}{400}}\",\n                r\"\\frac{63}{400}\",\n                20,\n                0.943,\n            ),  # Short correct answer\n            (\n                r\"\\boxed{\\frac{63}{400}}\",\n                r\"\\frac{63}{400}\",\n                80,\n                0.547,\n            ),  # Long correct answer\n            # Wrong answers with different lengths\n            (\n                r\"\\boxed{\\frac{64}{400}}\",\n                r\"\\frac{63}{400}\",\n                20,\n                -0.942,\n            ),  # Short wrong answer\n            (\n                r\"\\boxed{\\frac{64}{400}}\",\n                r\"\\frac{63}{400}\",\n                80,\n                -0.547,\n            ),  # Long wrong answer\n        ]\n\n        for content, solution, content_len, expected_reward in test_cases:\n            # Pad content to desired length\n            padded_content = content + \" \" * (content_len - len(content))\n            completion = [[{\"content\": padded_content}]]\n\n            rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])\n            self.assertAlmostEqual(rewards[0], expected_reward, places=2)\n\n    def test_format_reward_specific_multiline(self):\n        \"\"\"Test format_reward with a specific multiline input.\"\"\"\n        inputs = \"<think>\\nI will count each distinct object in the image:\\n1. Purple scooter\\n2. Red bicycle\\n3. Green motorcycle\\n4. Gray sedan\\n5. Yellow school bus\\n6. Small green double-decker bus\\n7. Small red car\\n8. Small purple car\\n9. 
Small gray dirt bike\\n\\nThere are 9 distinct objects in total.\\n</think>\\n<answer>\\n9\\n</answer>\"\n        completion = [[{\"content\": inputs}]]\n        rewards = format_reward(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_same_length_responses(self):\n        \"\"\"Test len_reward when all responses have the same length.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}\"}],\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}\"}],\n        ]\n        solutions = [r\"\\frac{63}{400}\", r\"\\frac{63}{400}\"]\n\n        rewards = len_reward(completions, solutions)\n        self.assertEqual(rewards, [0.0, 0.0])\n\n    def test_different_lengths_correct_answers(self):\n        \"\"\"Test len_reward with different length correct answers.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}\"}],  # shorter\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}  \" + \"x\" * 10}],  # longer\n        ]\n        solutions = [r\"\\frac{63}{400}\", r\"\\frac{63}{400}\"]\n\n        rewards = len_reward(completions, solutions)\n        self.assertGreater(rewards[0], rewards[1])  # shorter answer should get higher reward\n        self.assertAlmostEqual(rewards[0], 0.5)  # shortest correct answer gets maximum reward\n\n    def test_different_lengths_incorrect_answers(self):\n        \"\"\"Test len_reward with different length incorrect answers.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}\"}],  # shorter\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}  \" + \"x\" * 10}],  # longer\n        ]\n        solutions = [r\"\\frac{63}{400}\", r\"\\frac{63}{400}\"]\n\n        rewards = len_reward(completions, solutions)\n        self.assertLessEqual(rewards[0], 0.0)  # incorrect answers should get non-positive rewards\n        self.assertLessEqual(rewards[1], 0.0)\n        self.assertGreater(rewards[0], rewards[1])  # shorter 
answer should still be penalized less\n\n    def test_mixed_correctness(self):\n        \"\"\"Test len_reward with mix of correct and incorrect answers of different lengths.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}\"}],  # correct, shorter\n            [{\"content\": r\"\\boxed{\\frac{63}{400}}  \" + \"x\" * 10}],  # correct, longer\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}\"}],  # incorrect, shorter\n            [{\"content\": r\"\\boxed{\\frac{64}{400}}  \" + \"x\" * 10}],  # incorrect, longer\n        ]\n        solutions = [r\"\\frac{63}{400}\"] * 4\n\n        rewards = len_reward(completions, solutions)\n\n        # Shortest correct answer should get positive reward\n        self.assertGreater(rewards[0], 0.0)\n\n        # Longer correct answer might get negative reward:\n        self.assertGreater(rewards[2], rewards[1])\n        self.assertGreaterEqual(rewards[1], rewards[3])\n\n        # Incorrect answers should get non-positive rewards\n        self.assertLessEqual(rewards[2], 0.0)\n        self.assertLessEqual(rewards[3], 0.0)\n\n        # Shorter answers should get better rewards within their correctness category\n        self.assertGreater(rewards[0], rewards[1])  # correct answers\n        self.assertGreater(rewards[2], rewards[3])  # incorrect answers\n\n    def test_unparseable_solution(self):\n        \"\"\"Test len_reward with unparseable solution.\"\"\"\n        completions = [\n            [{\"content\": r\"\\boxed{answer}\"}],\n            [{\"content\": r\"\\boxed{answer} \" + \"x\" * 10}],\n        ]\n        solutions = [\"unparseable_latex\", \"unparseable_latex\"]\n\n        rewards = len_reward(completions, solutions)\n        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still get better reward\n        self.assertAlmostEqual(rewards[0], 0.5)  # treated as correct, shortest gets maximum reward\n\n\nclass TestRepetitionPenaltyReward(unittest.TestCase):\n    
def test_positive_max_penalty_raises_value_error(self):\n        with self.assertRaises(ValueError):\n            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)\n        with self.assertRaisesRegex(ValueError, \"max_penalty 1.5 should not be positive\"):\n            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)\n\n    def test_no_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)\n        completions = [[{\"content\": \"this is a test sentence\"}]]\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def test_full_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)\n        completions = [[{\"content\": \"this this this this this\"}]]\n\n        rewards = reward_fn(completions)\n        # (1 - 1/4) * -1 = -0.75\n        self.assertEqual(rewards, [-0.75])\n\n    def test_partial_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)\n        completions = [[{\"content\": \"this is a this is a test\"}]]\n\n        rewards = reward_fn(completions)\n        # Unique 2-grams: (this, is), (is, a), (a, this), (a, test).  
4 unique out of 6 total\n        # (1 - 4/6) * -1 = -1/3 = -0.3333...\n        self.assertAlmostEqual(rewards[0], -1 / 3)\n\n    def test_multiple_completions(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)\n        completions = [\n            [{\"content\": \"this is a test\"}],\n            [{\"content\": \"test test test test\"}],\n        ]\n\n        rewards = reward_fn(completions)\n        # Completion 1:  (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0\n        # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25\n        self.assertAlmostEqual(rewards[0], 0.0)\n        self.assertAlmostEqual(rewards[1], -0.25)\n\n    def test_empty_completion(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)\n        completions = [[{\"content\": \"\"}]]\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def test_different_ngram_size(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)\n        completions = [[{\"content\": \"this is a this is a test\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertAlmostEqual(rewards[0], -0.4)\n\n    def test_mixed_case(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)\n        completions = [\n            [{\"content\": \"This is A Test\"}],\n            [{\"content\": \"this IS a test\"}],\n        ]\n\n        rewards = reward_fn(completions)\n        # both completions should produce the same reward, because the text gets lowercased\n        self.assertAlmostEqual(rewards[0], rewards[1])\n\n    def test_one_word_completion(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"word\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def 
test_two_word_completion(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"two words\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def test_three_word_completion(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"three different words\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def test_three_word_repetition_completion(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"word word word word\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [-0.5])\n\n    def test_four_word_completion_with_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"one two one two\"}]]\n\n        rewards = reward_fn(completions)\n        # ngrams are (one two one) (two one two). unique is 2 and count is 2, therefore (1-1) * -1.\n        self.assertEqual(rewards, [0.0])\n\n    def test_five_word_completion_with_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)\n        completions = [[{\"content\": \"A B C A B\"}]]\n\n        rewards = reward_fn(completions)\n        # (A B C) (B C A) (C A B). unique is 3. 
count is 3 (1-1) * -.5 = 0\n        self.assertEqual(rewards, [0.0])\n\n    def test_six_word_completion_with_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"A B C A B C\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [-0.25])\n\n    def test_long_completion_with_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"A B C A B C E F G A B C A B C\"}]]\n        rewards = reward_fn(completions)\n        self.assertAlmostEqual(rewards[0], -0.3846, places=4)\n\n    def test_long_completion_without_repetition(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)\n        completions = [[{\"content\": \"A B C D E F G H I J K L\"}]]\n\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [0.0])\n\n    def test_tag_count_rewards_all_correct(self):\n        \"\"\"Test tag_count_reward with correct tags.\"\"\"\n        completion = [[{\"content\": \"<think>\\nSome reasoning\\n</think>\\n<answer>\\nThe answer\\n</answer>\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_tag_count_rewards_missing_think_begin(self):\n        \"\"\"Test tag_count_reward with missing <think> tag.\"\"\"\n        completion = [[{\"content\": \"Some reasoning\\n</think>\\n<answer>\\nThe answer\\n</answer>\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 0.75)\n\n    def test_tag_count_rewards_missing_think_end(self):\n        \"\"\"Test tag_count_reward with missing </think> tag.\"\"\"\n        completion = [[{\"content\": \"<think>\\nSome reasoning\\n<answer>\\nThe answer\\n</answer>\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 0.75)\n\n    def 
test_tag_count_rewards_missing_answer_begin(self):\n        \"\"\"Test tag_count_reward with missing <answer> tag.\"\"\"\n        completion = [[{\"content\": \"<think>\\nSome reasoning\\n</think>\\nThe answer\\n</answer>\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 0.75)\n\n    def test_tag_count_rewards_missing_answer_end(self):\n        \"\"\"Test tag_count_reward with missing </answer> tag.\"\"\"\n        completion = [[{\"content\": \"<think>\\nSome reasoning\\n</think>\\n<answer>\\nThe answer\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 0.75)\n\n    def test_tag_count_rewards_missing_all_tags(self):\n        \"\"\"Test tag_count_reward with missing all tags.\"\"\"\n        completion = [[{\"content\": \"Some reasoning\\nThe answer\"}]]\n        rewards = tag_count_reward(completion)\n        self.assertEqual(rewards[0], 0.0)\n\n    def test_full_repetition_with_language(self):\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language=\"en\")\n        completions = [[{\"content\": \"that that that that that\"}]]\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [-0.75])\n        # begin test for zh language\n        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language=\"zh\")\n        completions = [[{\"content\": \"这个这个这个这个这个\"}]]\n        rewards = reward_fn(completions)\n        self.assertEqual(rewards, [-0.75])\n\n    def test_soft_overlong_punishment_short_completion(self):\n        \"\"\"Test soft overlong punishment reward function with a short completion.\"\"\"\n        # length 50, with max=100 and soft cache=20, reward should be 0.\n        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)\n        completion_ids = [[1] * 50]  # 50 <= 80\n        rewards = reward_fn(completion_ids=completion_ids)\n        self.assertEqual(rewards, 
[0])\n\n    def test_soft_overlong_punishment_long_completion(self):\n        \"\"\"Test soft overlong punishment reward function with a longer than max completion.\"\"\"\n        # 110 > 100, reward should be -1.\n        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)\n        completion_ids = [[1] * 110]\n        rewards = reward_fn(completion_ids)\n        self.assertEqual(rewards, [-1])\n\n    def test_soft_overlong_punishment_intermediate_completion(self):\n        \"\"\"Test soft overlong punishment reward function for intermediate length completion.\"\"\"\n        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)\n        completion_ids = [[1] * 90]  # 90 is between 80 and 100\n        rewards = reward_fn(completion_ids)\n        self.assertAlmostEqual(rewards[0], -0.5, places=4)\n\n\nclass TestCodeFormat(unittest.TestCase):\n    def test_correct_python_format(self):\n        \"\"\"Test code format reward with correct Python format.\"\"\"\n        completion = [\n            [\n                {\n                    \"content\": \"<think>\\nLet's solve this\\nStep 1: First step\\n</think>\\n<answer>\\n```python\\ndef hello():\\n    print('world')\\n```\\n</answer>\"\n                }\n            ]\n        ]\n        reward_fn = get_code_format_reward(language=\"python\")\n        rewards = reward_fn(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_incorrect_formats(self):\n        \"\"\"Test code format reward with various incorrect formats.\"\"\"\n        incorrect_formats = [\n            # Missing think/answer tags\n            \"```python\\ndef hello():\\n    print('world')\\n```\",\n            # Missing code block\n            \"<think>Some thinking</think><answer>Just plain text</answer>\",\n            # Wrong language\n            \"<think>Analysis</think><answer>```javascript\\nconsole.log('hello');\\n```</answer>\",\n            # Missing language 
identifier\n            \"<think>Analysis</think><answer>```\\ndef hello(): pass\\n```</answer>\",\n            # Wrong order of tags\n            \"<answer>```python\\ndef hello(): pass\\n```</answer><think>Analysis</think>\",\n        ]\n\n        reward_fn = get_code_format_reward(language=\"python\")\n        for fmt in incorrect_formats:\n            completion = [[{\"content\": fmt}]]\n            rewards = reward_fn(completion)\n            self.assertEqual(rewards[0], 0.0)\n\n    def test_multiple_code_blocks(self):\n        \"\"\"Test format reward with multiple code blocks in think and answer sections.\"\"\"\n        completion = [\n            [\n                {\n                    \"content\": \"<think>\\nHere's an example:\\n```python\\nx = 1\\n```\\nNow the solution:\\n</think>\\n<answer>\\n```python\\ndef solution():\\n    return 42\\n```\\n</answer>\"\n                }\n            ]\n        ]\n        reward_fn = get_code_format_reward(language=\"python\")\n        rewards = reward_fn(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n    def test_different_languages(self):\n        \"\"\"Test code format reward with different programming languages.\"\"\"\n        completion = [\n            [\n                {\n                    \"content\": \"<think>\\nAnalysis\\n</think>\\n<answer>\\n```javascript\\nconsole.log('hello');\\n```\\n</answer>\"\n                }\n            ]\n        ]\n\n        # Test with JavaScript\n        js_reward_fn = get_code_format_reward(language=\"javascript\")\n        rewards = js_reward_fn(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n        # Same completion should fail for Python\n        py_reward_fn = get_code_format_reward(language=\"python\")\n        rewards = py_reward_fn(completion)\n        self.assertEqual(rewards[0], 0.0)\n\n    def test_multiline_code(self):\n        \"\"\"Test format reward with complex multiline code blocks.\"\"\"\n        completion = [\n            [\n   
             {\n                    \"content\": \"<think>\\nHere's the analysis\\n</think>\\n<answer>\\n```python\\nclass Solution:\\n    def __init__(self):\\n        self.value = 42\\n        \\n    def get_value(self):\\n        return self.value\\n```\\n</answer>\"\n                }\n            ]\n        ]\n        reward_fn = get_code_format_reward(language=\"python\")\n        rewards = reward_fn(completion)\n        self.assertEqual(rewards[0], 1.0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_data.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport unittest\nfrom dataclasses import asdict\n\nfrom datasets import DatasetDict, load_dataset\n\nfrom open_r1.configs import DatasetConfig, DatasetMixtureConfig, ScriptArguments\nfrom open_r1.utils.data import get_dataset\n\n\nclass TestGetDataset(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        cls.dataset_name = \"trl-internal-testing/zen\"\n        cls.dataset_config = \"conversational_preference\"\n        cls.ref_dataset = load_dataset(cls.dataset_name, cls.dataset_config)\n\n    def test_dataset_and_config_name(self):\n        args = ScriptArguments(dataset_name=self.dataset_name, dataset_config=self.dataset_config)\n        dataset = get_dataset(args)\n        self.assertIsInstance(dataset, DatasetDict)\n        self.assertIn(\"train\", dataset)\n        self.assertEqual(len(dataset[\"train\"]), len(self.ref_dataset[\"train\"]))\n\n    def test_unweighted_mixture(self):\n        \"\"\"Mix train and test splits of the same dataset.\"\"\"\n        dataset_configs = [\n            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split=\"train\", columns=None, weight=None),\n            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split=\"test\", columns=None, weight=None),\n        ]\n        dataset_mixture = DatasetMixtureConfig(\n            datasets=dataset_configs,\n   
     )\n        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))\n        dataset = get_dataset(args)\n        self.assertIsInstance(dataset, DatasetDict)\n        self.assertIn(\"train\", dataset)\n        self.assertEqual(len(dataset[\"train\"]), len(self.ref_dataset[\"train\"]) + len(self.ref_dataset[\"test\"]))\n\n    def test_weighted_mixture(self):\n        \"\"\"Test loading a dataset mixture with weights.\"\"\"\n        dataset_configs = [\n            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split=\"train\", columns=None, weight=0.25),\n            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split=\"test\", columns=None, weight=0.5),\n        ]\n        dataset_mixture = DatasetMixtureConfig(\n            datasets=dataset_configs,\n        )\n        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))\n        dataset = get_dataset(args)\n        self.assertIsInstance(dataset, DatasetDict)\n        self.assertIn(\"train\", dataset)\n        self.assertEqual(\n            len(dataset[\"train\"]), len(self.ref_dataset[\"train\"]) // 4 + len(self.ref_dataset[\"test\"]) // 2\n        )\n\n    def test_mixture_and_test_split(self):\n        \"\"\"Test loading a dataset mixture with test split.\"\"\"\n        dataset_configs = [\n            DatasetConfig(\n                id=self.dataset_name, config=self.dataset_config, split=\"train[:10]\", columns=None, weight=None\n            ),\n        ]\n        dataset_mixture = DatasetMixtureConfig(datasets=dataset_configs, test_split_size=0.2)\n        args = ScriptArguments(dataset_name=None, dataset_mixture=asdict(dataset_mixture))\n        dataset = get_dataset(args)\n        self.assertIsInstance(dataset, DatasetDict)\n        self.assertIn(\"train\", dataset)\n        self.assertIn(\"test\", dataset)\n        self.assertEqual(len(dataset[\"train\"]), 8)\n        self.assertEqual(len(dataset[\"test\"]), 2)\n\n    def 
test_mixture_column_selection(self):\n        \"\"\"Test loading a dataset mixture with column selection.\"\"\"\n        dataset_configs = [\n            DatasetConfig(\n                id=self.dataset_name,\n                config=self.dataset_config,\n                split=\"train\",\n                columns=[\"prompt\", \"chosen\"],\n                weight=None,\n            ),\n        ]\n        dataset_mixture = DatasetMixtureConfig(\n            datasets=dataset_configs,\n        )\n        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))\n        dataset = get_dataset(args)\n        self.assertIsInstance(dataset, DatasetDict)\n        self.assertIn(\"train\", dataset)\n        self.assertIn(\"prompt\", dataset[\"train\"].column_names)\n        self.assertIn(\"chosen\", dataset[\"train\"].column_names)\n\n    def test_mixture_with_mismatched_columns(self):\n        dataset_configs = [\n            DatasetConfig(\n                id=self.dataset_name, config=self.dataset_config, split=\"train\", columns=[\"prompt\"], weight=None\n            ),\n            DatasetConfig(\n                id=self.dataset_name, config=self.dataset_config, split=\"train\", columns=[\"chosen\"], weight=None\n            ),\n        ]\n        dataset_mixture = DatasetMixtureConfig(\n            datasets=dataset_configs,\n        )\n        with self.assertRaises(ValueError) as context:\n            _ = ScriptArguments(dataset_mixture=asdict(dataset_mixture))\n        self.assertIn(\"Column names must be consistent\", str(context.exception))\n\n    def test_no_dataset_name_or_mixture(self):\n        with self.assertRaises(ValueError) as context:\n            _ = ScriptArguments(dataset_name=None, dataset_mixture=None)\n        self.assertIn(\"Either `dataset_name` or `dataset_mixture` must be provided\", str(context.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  }
]