Showing preview only (6,665K chars total). Download the full file or copy to clipboard to get everything.
Repository: huggingface/trl
Branch: main
Commit: 8e6e0626ebec
Files: 380
Total size: 6.3 MB
Directory structure:
gitextract_r678upi2/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug-report.yml
│ │ ├── feature-request.yml
│ │ └── new-trainer-addition.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ ├── codeql/
│ │ └── custom-queries.qls
│ └── workflows/
│ ├── build_documentation.yml
│ ├── build_pr_documentation.yml
│ ├── clear_cache.yml
│ ├── codeQL.yml
│ ├── docker-build.yml
│ ├── issue_auto_labeller.yml
│ ├── pr_style_bot.yml
│ ├── publish.yml
│ ├── slow-tests.yml
│ ├── tests-experimental.yml
│ ├── tests.yml
│ ├── tests_latest.yml
│ ├── tests_transformers_branch.yml
│ ├── trufflehog.yml
│ └── upload_pr_documentation.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── MIGRATION.md
├── Makefile
├── README.md
├── RELEASE.md
├── VERSION
├── docker/
│ ├── trl/
│ │ └── Dockerfile
│ └── trl-dev/
│ └── Dockerfile
├── docs/
│ └── source/
│ ├── _toctree.yml
│ ├── async_grpo_trainer.md
│ ├── bco_trainer.md
│ ├── bema_for_reference_model.md
│ ├── callbacks.md
│ ├── chat_template_utils.md
│ ├── clis.md
│ ├── community_tutorials.md
│ ├── cpo_trainer.md
│ ├── customization.md
│ ├── data_utils.md
│ ├── dataset_formats.md
│ ├── deepspeed_integration.md
│ ├── distributing_training.md
│ ├── dpo_trainer.md
│ ├── example_overview.md
│ ├── experimental_overview.md
│ ├── gfpo.md
│ ├── gkd_trainer.md
│ ├── gold_trainer.md
│ ├── grpo_trainer.md
│ ├── grpo_with_replay_buffer.md
│ ├── gspo_token.md
│ ├── index.md
│ ├── installation.md
│ ├── jobs_training.md
│ ├── judges.md
│ ├── kernels_hub.md
│ ├── kto_trainer.md
│ ├── liger_kernel_integration.md
│ ├── lora_without_regret.md
│ ├── merge_model_callback.md
│ ├── minillm_trainer.md
│ ├── nash_md_trainer.md
│ ├── nemo_gym.md
│ ├── online_dpo_trainer.md
│ ├── openenv.md
│ ├── orpo_trainer.md
│ ├── paper_index.md
│ ├── papo_trainer.md
│ ├── peft_integration.md
│ ├── ppo_trainer.md
│ ├── prm_trainer.md
│ ├── ptt_integration.md
│ ├── quickstart.md
│ ├── rapidfire_integration.md
│ ├── reducing_memory_usage.md
│ ├── reward_trainer.md
│ ├── rewards.md
│ ├── rloo_trainer.md
│ ├── script_utils.md
│ ├── sft_trainer.md
│ ├── speeding_up_training.md
│ ├── trackio_integration.md
│ ├── unsloth_integration.md
│ ├── use_model.md
│ ├── vllm_integration.md
│ ├── winrate_callback.md
│ └── xpo_trainer.md
├── examples/
│ ├── README.md
│ ├── accelerate_configs/
│ │ ├── alst_ulysses_4gpu.yaml
│ │ ├── context_parallel_2gpu.yaml
│ │ ├── deepspeed_zero1.yaml
│ │ ├── deepspeed_zero2.yaml
│ │ ├── deepspeed_zero3.yaml
│ │ ├── fsdp1.yaml
│ │ ├── fsdp2.yaml
│ │ ├── multi_gpu.yaml
│ │ └── single_gpu.yaml
│ ├── cli_configs/
│ │ └── example_config.yaml
│ ├── datasets/
│ │ ├── deepmath_103k.py
│ │ ├── hh-rlhf-helpful-base.py
│ │ ├── llava_instruct_mix.py
│ │ ├── lm-human-preferences-descriptiveness.py
│ │ ├── lm-human-preferences-sentiment.py
│ │ ├── math_shepherd.py
│ │ ├── prm800k.py
│ │ ├── rlaif-v.py
│ │ ├── tldr.py
│ │ ├── tldr_preference.py
│ │ ├── ultrafeedback-prompt.py
│ │ └── ultrafeedback.py
│ ├── notebooks/
│ │ ├── README.md
│ │ ├── grpo_agent.ipynb
│ │ ├── grpo_functiongemma_browsergym_openenv.ipynb
│ │ ├── grpo_ministral3_vl.ipynb
│ │ ├── grpo_qwen3_vl.ipynb
│ │ ├── grpo_rnj_1_instruct.ipynb
│ │ ├── grpo_trl_lora_qlora.ipynb
│ │ ├── openenv_sudoku_grpo.ipynb
│ │ ├── openenv_wordle_grpo.ipynb
│ │ ├── sft_ministral3_vl.ipynb
│ │ ├── sft_nemotron_3.ipynb
│ │ ├── sft_qwen_vl.ipynb
│ │ ├── sft_tool_calling.ipynb
│ │ └── sft_trl_lora_qlora.ipynb
│ └── scripts/
│ ├── async_grpo.py
│ ├── bco.py
│ ├── cpo.py
│ ├── dpo.py
│ ├── dpo_vlm.py
│ ├── evals/
│ │ └── judge_tldr.py
│ ├── gkd.py
│ ├── grpo_2048.py
│ ├── grpo_agent.py
│ ├── grpo_vlm.py
│ ├── gspo.py
│ ├── gspo_vlm.py
│ ├── kto.py
│ ├── mpo_vlm.py
│ ├── nash_md.py
│ ├── nemo_gym/
│ │ ├── README.md
│ │ ├── config.yaml
│ │ ├── deepspeed_zero3.yaml
│ │ ├── submit.sh
│ │ └── train_multi_environment.py
│ ├── online_dpo.py
│ ├── online_dpo_vlm.py
│ ├── openenv/
│ │ ├── browsergym.py
│ │ ├── browsergym_llm.py
│ │ ├── carla.py
│ │ ├── catch.py
│ │ ├── echo.py
│ │ ├── sudoku.py
│ │ ├── sudoku_prompt.txt
│ │ └── wordle.py
│ ├── orpo.py
│ ├── ppo/
│ │ ├── ppo.py
│ │ └── ppo_tldr.py
│ ├── prm.py
│ ├── reward_modeling.py
│ ├── rloo.py
│ ├── rloo_vlm.py
│ ├── sft.py
│ ├── sft_gemma3.py
│ ├── sft_gpt_oss.py
│ ├── sft_nemotron_3.py
│ ├── sft_tiny_aya_tool_calling.py
│ ├── sft_video_llm.py
│ ├── sft_vlm.py
│ ├── sft_vlm_gemma3.py
│ ├── tiny_aya_chat_template.jinja
│ └── xpo.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── add_copyrights.py
│ ├── generate_harmony_dataset.py
│ ├── generate_tiny_models.py
│ ├── generate_toolcall_dataset.py
│ ├── generate_zen_dataset.py
│ ├── generate_zen_image_dataset.py
│ ├── generate_zen_multi_image_dataset.py
│ └── log_reports.py
├── tests/
│ ├── __init__.py
│ ├── conftest.py
│ ├── data/
│ │ └── template.jinja
│ ├── distributed/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ └── accelerate_configs/
│ │ │ ├── ddp.yaml
│ │ │ ├── fsdp2.yaml
│ │ │ ├── zero2.yaml
│ │ │ └── zero3.yaml
│ │ └── test_distributed.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ ├── test_async_grpo_trainer.py
│ │ ├── test_bco_trainer.py
│ │ ├── test_cpo_trainer.py
│ │ ├── test_dppo_trainer.py
│ │ ├── test_gkd_trainer.py
│ │ ├── test_gold_trainer.py
│ │ ├── test_grpo_with_replay_buffer_trainer.py
│ │ ├── test_gspo_token_trainer.py
│ │ ├── test_judges.py
│ │ ├── test_kto_trainer.py
│ │ ├── test_merge_model_callback.py
│ │ ├── test_minillm_trainer.py
│ │ ├── test_modeling_value_head.py
│ │ ├── test_nash_md_trainer.py
│ │ ├── test_online_dpo_trainer.py
│ │ ├── test_orpo_trainer.py
│ │ ├── test_ppo_trainer.py
│ │ ├── test_prm_trainer.py
│ │ ├── test_utils.py
│ │ ├── test_winrate_callback.py
│ │ ├── test_xpo_trainer.py
│ │ └── testing_utils.py
│ ├── test_activation_offloading.py
│ ├── test_callbacks.py
│ ├── test_chat_template_utils.py
│ ├── test_cli.py
│ ├── test_cli_utils.py
│ ├── test_data_utils.py
│ ├── test_dpo_trainer.py
│ ├── test_grpo_trainer.py
│ ├── test_model_utils.py
│ ├── test_reward_trainer.py
│ ├── test_rewards.py
│ ├── test_rich_progress_callback.py
│ ├── test_rloo_trainer.py
│ ├── test_sft_trainer.py
│ ├── test_skills.py
│ ├── test_skills_cli.py
│ ├── test_utils.py
│ ├── test_vllm_client_server.py
│ ├── testing_constants.py
│ └── testing_utils.py
└── trl/
├── __init__.py
├── _compat.py
├── _lazy_module.py
├── accelerate_configs/
│ ├── fsdp1.yaml
│ ├── fsdp2.yaml
│ ├── multi_gpu.yaml
│ ├── single_gpu.yaml
│ ├── zero1.yaml
│ ├── zero2.yaml
│ └── zero3.yaml
├── chat_template_utils.py
├── cli/
│ ├── __init__.py
│ ├── accelerate_config.py
│ ├── accelerate_launcher.py
│ ├── commands/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── env.py
│ │ ├── skills.py
│ │ ├── training.py
│ │ └── vllm_serve.py
│ └── main.py
├── data_utils.py
├── experimental/
│ ├── __init__.py
│ ├── async_grpo/
│ │ ├── __init__.py
│ │ ├── async_grpo_config.py
│ │ ├── async_grpo_trainer.py
│ │ └── async_rollout_worker.py
│ ├── bco/
│ │ ├── __init__.py
│ │ ├── bco_config.py
│ │ └── bco_trainer.py
│ ├── bema_for_ref_model/
│ │ ├── __init__.py
│ │ ├── callback.py
│ │ └── dpo_trainer.py
│ ├── cpo/
│ │ ├── __init__.py
│ │ ├── cpo_config.py
│ │ └── cpo_trainer.py
│ ├── dppo/
│ │ ├── __init__.py
│ │ ├── dppo_config.py
│ │ └── dppo_trainer.py
│ ├── gfpo/
│ │ ├── __init__.py
│ │ ├── gfpo_config.py
│ │ └── gfpo_trainer.py
│ ├── gkd/
│ │ ├── __init__.py
│ │ ├── gkd_config.py
│ │ └── gkd_trainer.py
│ ├── gold/
│ │ ├── __init__.py
│ │ ├── gold.py
│ │ ├── gold_config.py
│ │ └── gold_trainer.py
│ ├── grpo_with_replay_buffer/
│ │ ├── __init__.py
│ │ ├── grpo_with_replay_buffer_config.py
│ │ └── grpo_with_replay_buffer_trainer.py
│ ├── gspo_token/
│ │ ├── __init__.py
│ │ └── grpo_trainer.py
│ ├── judges/
│ │ ├── __init__.py
│ │ └── judges.py
│ ├── kto/
│ │ ├── __init__.py
│ │ ├── kto_config.py
│ │ └── kto_trainer.py
│ ├── merge_model_callback.py
│ ├── minillm/
│ │ ├── __init__.py
│ │ ├── minillm_config.py
│ │ └── minillm_trainer.py
│ ├── nash_md/
│ │ ├── __init__.py
│ │ ├── nash_md_config.py
│ │ └── nash_md_trainer.py
│ ├── online_dpo/
│ │ ├── __init__.py
│ │ ├── online_dpo_config.py
│ │ └── online_dpo_trainer.py
│ ├── openenv/
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── orpo/
│ │ ├── __init__.py
│ │ ├── orpo_config.py
│ │ └── orpo_trainer.py
│ ├── papo/
│ │ ├── __init__.py
│ │ ├── papo_config.py
│ │ └── papo_trainer.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── modeling_value_head.py
│ │ ├── ppo_config.py
│ │ └── ppo_trainer.py
│ ├── prm/
│ │ ├── __init__.py
│ │ ├── prm_config.py
│ │ └── prm_trainer.py
│ ├── utils.py
│ ├── winrate_callback.py
│ └── xpo/
│ ├── __init__.py
│ ├── xpo_config.py
│ └── xpo_trainer.py
├── extras/
│ ├── __init__.py
│ ├── dataset_formatting.py
│ └── profiling.py
├── generation/
│ ├── __init__.py
│ ├── vllm_client.py
│ └── vllm_generation.py
├── import_utils.py
├── models/
│ ├── __init__.py
│ ├── activation_offloading.py
│ └── utils.py
├── py.typed
├── rewards/
│ ├── __init__.py
│ ├── accuracy_rewards.py
│ ├── format_rewards.py
│ └── other_rewards.py
├── scripts/
│ ├── __init__.py
│ ├── _hf_argparser.py
│ ├── dpo.py
│ ├── env.py
│ ├── grpo.py
│ ├── kto.py
│ ├── reward.py
│ ├── rloo.py
│ ├── sft.py
│ ├── utils.py
│ └── vllm_serve.py
├── skills/
│ ├── __init__.py
│ ├── cli.py
│ ├── skills.py
│ └── trl-training/
│ └── SKILL.md
├── templates/
│ ├── completions_dataset_card.md
│ ├── lm_model_card.md
│ └── rm_model_card.md
└── trainer/
├── __init__.py
├── base_config.py
├── base_trainer.py
├── callbacks.py
├── dpo_config.py
├── dpo_trainer.py
├── grpo_config.py
├── grpo_trainer.py
├── kto_config.py
├── kto_trainer.py
├── model_config.py
├── reward_config.py
├── reward_trainer.py
├── rloo_config.py
├── rloo_trainer.py
├── sft_config.py
├── sft_trainer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.yml
================================================
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve TRL
labels: [ "bug" ]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report! 🤗
🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
value: |
```python
from trl import ...
```
outputs:
```
Traceback (most recent call last):
File "example.py", line 42, in <module>
...
```
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
You can get this information by running `trl env` in your terminal.
placeholder: Copy-paste the output of `trl env`
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Checklist
description: |
Before submitting, please confirm that you've completed each of the following.
If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
options:
- label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
required: true
- label: "I have included my system information"
required: true
- label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any code provided is properly formatted in code blocks (no screenshots; [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any traceback provided is complete"
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/feature-request.yml
================================================
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new TRL feature
labels: [ "Feature request" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the [CONTRIBUTING.md](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) guide.
================================================
FILE: .github/ISSUE_TEMPLATE/new-trainer-addition.yml
================================================
name: "\U0001F31F New trainer addition"
description: Submit a proposal/request to implement a new trainer for a post-training method
labels: [ "New trainer" ]
body:
- type: textarea
id: description-request
validations:
required: true
attributes:
label: Method description
description: |
Please provide all important information about the method.
- type: checkboxes
id: information-tasks
attributes:
label: Open source status
description: |
Please note that if the method implementation, the model weights, or the training datasets aren't available, we are less likely to implement it in `trl`.
options:
- label: "The method implementation is available"
- label: "The model weights are available"
- label: "The training datasets are available"
- type: textarea
id: additional-info
attributes:
label: Provide useful links for the implementation
description: |
Please provide information regarding the implementation, the weights, and the authors.
Please mention the authors by @gh-username if you're aware of their usernames.
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet though.
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies that are required for this change.
Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
================================================
FILE: .github/codeql/custom-queries.qls
================================================
import codeql
from WorkflowString interpolation, Workflow workflow
where
interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
interpolation.getStringValue().matches("${{ github.event.review.body }}") or
interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
interpolation.getStringValue().matches("${{ github.event.* }}") and
(
step.getKey() = "run" or // Injection in run
step.getKey() = "env" or // Injection via env
step.getKey() = "with" // Injection via with
)
select workflow, "🚨 Do not use directly as input of action"
================================================
FILE: .github/workflows/build_documentation.yml
================================================
name: Build documentation
on:
push:
branches:
- main
- doc-builder*
- v*-release
env:
TRL_EXPERIMENTAL_SILENCE: 1
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: trl
version_tag_suffix: ""
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
================================================
FILE: .github/workflows/build_pr_documentation.yml
================================================
name: Build PR Documentation
on:
pull_request:
env:
TRL_EXPERIMENTAL_SILENCE: 1
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
if: github.event.pull_request.draft == false
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
version_tag_suffix: ""
================================================
FILE: .github/workflows/clear_cache.yml
================================================
name: "Cleanup Cache"
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
cleanup:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v6
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
REPO=${{ github.repository }}
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )
## Disable exit-on-error so that a failed cache-key deletion does not fail the workflow.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR
do
gh actions-cache delete $cacheKey -R $REPO --confirm
done
echo "Done"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
================================================
FILE: .github/workflows/codeQL.yml
================================================
name: "CodeQL Analysis - Workflows"
on:
workflow_dispatch:
jobs:
analyze:
name: "Analyze GitHub Workflows"
runs-on: ubuntu-latest
permissions:
security-events: write
actions: read
contents: read
steps:
- name: "Checkout repository"
uses: actions/checkout@v6
- name: "Initialize CodeQL"
uses: github/codeql-action/init@v2
with:
languages: "yaml"
queries: +security-and-quality, ./.github/codeql/custom-queries.qls
- name: "Perform CodeQL Analysis"
uses: github/codeql-action/analyze@v2
================================================
FILE: .github/workflows/docker-build.yml
================================================
name: Build TRL Docker image
on:
push:
branches:
- main
workflow_dispatch:
concurrency:
group: docker-image-builds
cancel-in-progress: false
jobs:
trl:
name: "Build and push TRL Docker image"
runs-on:
group: aws-general-8-plus
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Get TRL version from PyPI
run: |
VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push
uses: docker/build-push-action@v6
with:
context: docker/trl
push: true
tags: |
huggingface/trl:${{ env.VERSION }}
huggingface/trl
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-dev:
name: "Build and push TRL Dev Docker image"
runs-on:
group: aws-general-8-plus
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push
uses: docker/build-push-action@v6
with:
context: docker/trl-dev
push: true
tags: |
huggingface/trl:dev
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Dev Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
================================================
FILE: .github/workflows/issue_auto_labeller.yml
================================================
name: "Hugging Face Issue Labeler"
on:
issues:
types: opened
jobs:
triage:
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- uses: actions/checkout@v6
- uses: August-murr/auto-labeler@0.0.1
with:
hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}
================================================
FILE: .github/workflows/pr_style_bot.yml
================================================
name: PR Style Bot
on:
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
run-style-bot:
if: >
contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null
runs-on: ubuntu-latest
steps:
- name: Extract PR details
id: pr_info
uses: actions/github-script@v8
with:
script: |
const prNumber = context.payload.issue.number;
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// We capture both the branch ref and the "full_name" of the head repo
// so that we can check out the correct repository & branch (including forks).
core.setOutput("prNumber", prNumber);
core.setOutput("headRef", pr.head.ref);
core.setOutput("headRepoFullName", pr.head.repo.full_name);
- name: Check out PR branch
uses: actions/checkout@v6
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
with:
# Instead of checking out the base repo, use the contributor's repo name
repository: ${{ env.HEADREPOFULLNAME }}
ref: ${{ env.HEADREF }}
# fetch-depth: 0 may be needed to be able to push
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Debug
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
run: |
echo "PR number: ${{ env.PRNUMBER }}"
echo "Head Ref: ${{ env.HEADREF }}"
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
- name: Set up Python
uses: actions/setup-python@v6
- name: Install dependencies
run: |
pip install ruff pre-commit
- name: Download Makefile from main branch
run: |
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
- name: Compare Makefiles
run: |
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
exit 1
fi
echo "No changes in Makefile. Proceeding..."
rm -rf main_Makefile
- name: Run make style and make quality
run: |
make precommit || true
- name: Commit and push changes
id: commit_and_push
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
# Configure git with the Actions bot user
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
# If there are changes after running style/quality, commit them
if [ -n "$(git status --porcelain)" ]; then
git add .
git commit -m "Apply style fixes"
# Push to the original contributor's forked branch
git push origin HEAD:${{ env.HEADREF }}
echo "changes_pushed=true" >> $GITHUB_OUTPUT
else
echo "No changes to commit."
echo "changes_pushed=false" >> $GITHUB_OUTPUT
fi
- name: Comment on PR with workflow run link
if: steps.commit_and_push.outputs.changes_pushed == 'true'
uses: actions/github-script@v8
with:
script: |
const prNumber = parseInt(process.env.prNumber, 10);
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
});
env:
prNumber: ${{ steps.pr_info.outputs.prNumber }}
================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to PyPI
on:
push:
branches:
- main
- v*-release
paths:
- "VERSION"
jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Read version
id: get_version
run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
- name: Debug - Show VERSION content
run: echo "Version is ${{ steps.get_version.outputs.version }}"
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build twine
- name: Build package
run: python -m build
- name: Publish to PyPI
if: ${{ !contains(steps.get_version.outputs.version, 'dev') }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
python -m twine upload dist/*
================================================
FILE: .github/workflows/slow-tests.yml
================================================
name: Slow tests (on push)
on:
push:
branches: [main]
paths:
# Run only when Python files are modified
- "trl/**.py"
- "examples/**.py"
env:
RUN_SLOW: "yes"
IS_GITHUB_CI: "1"
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
TRL_EXPERIMENTAL_SILENCE: 1
jobs:
run_all_tests_single_gpu:
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu"
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Install system dependencies
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog
- name: Run slow SFT tests on single GPU
if: always()
run: |
source .venv/bin/activate
make slow_tests
- name: Generate Report
if: always()
run: |
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu"
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Install system dependencies
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source .venv/bin/activate
make slow_tests
- name: Generate Reports
if: always()
run: |
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
rm *.txt
================================================
FILE: .github/workflows/tests-experimental.yml
================================================
name: Tests (experimental)
on:
pull_request:
paths:
# Run only when relevant files are modified
- "trl/experimental/**"
- "tests/experimental/**"
env:
TQDM_DISABLE: 1
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
PYTORCH_ALLOC_CONF: "expandable_segments:True"
TRL_EXPERIMENTAL_SILENCE: 1
jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v6
- name: Set up Python 3.13
uses: actions/setup-python@v6
with:
python-version: 3.13
- uses: pre-commit/action@v3.0.1
with:
extra_args: --all-files
tests:
name: Tests (experimental)
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.13
uses: actions/setup-python@v6
with:
python-version: 3.13
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test_experimental
================================================
FILE: .github/workflows/tests.yml
================================================
name: Tests
on:
push:
branches:
- main
- ci-*
pull_request:
paths:
# Run only when relevant files are modified
- ".github/**.yml"
- "examples/**.py"
- "scripts/**.py"
- "tests/**.py"
- "trl/**.py"
- "pyproject.toml"
# Exclude if only experimental code/tests
- "!trl/experimental/**"
- "!tests/experimental/**"
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
PYTORCH_ALLOC_CONF: "expandable_segments:True"
jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: 3.12
- uses: pre-commit/action@v3.0.1
with:
extra_args: --all-files
tests:
name: Tests
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python ${{ matrix.python-version }} and latest dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_dev:
name: Tests with dev dependencies
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
uv pip install -U git+https://github.com/huggingface/peft.git
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_wo_optional_deps:
name: Tests without optional dependencies
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[test]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 without optional dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_min_versions:
name: Tests with minimum versions
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install accelerate==1.4.0
uv pip install datasets==3.0.0
uv pip install transformers==4.56.2
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 and minimum dependencies versions
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
distributed_smoke:
name: Distributed smoke tests
runs-on:
group: aws-g5-12xlarge-cache
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
env:
CUDA_VISIBLE_DEVICES: "0,1"
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
- name: Run distributed smoke tests
run: |
source .venv/bin/activate
pytest -v tests/distributed/test_distributed.py
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of distributed smoke tests
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
================================================
FILE: .github/workflows/tests_latest.yml
================================================
name: Tests latest TRL release with dev dependencies
on:
schedule:
- cron: '0 0 * * *' # Runs daily at midnight UTC
workflow_dispatch:
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
TRL_EXPERIMENTAL_SILENCE: 1
jobs:
tests:
name: Tests latest TRL release with dev dependencies
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v6
with: { ref: v0.29-release }
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of latest TRL with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
================================================
FILE: .github/workflows/tests_transformers_branch.yml
================================================
name: Tests against Transformers branch
on:
workflow_dispatch:
inputs:
transformers_ref:
description: "Transformers git ref (branch, tag, or commit SHA)"
required: true
default: "main"
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
PYTORCH_ALLOC_CONF: "expandable_segments:True"
jobs:
tests_transformers_branch:
name: Tests with Transformers ${{ inputs.transformers_ref }}
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Transformers ${{ inputs.transformers_ref }}
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
distributed_smoke:
name: Distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
runs-on:
group: aws-g5-12xlarge-cache
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
env:
CUDA_VISIBLE_DEVICES: "0,1"
steps:
- name: Git checkout
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}
- name: Run distributed smoke tests
run: |
source .venv/bin/activate
pytest -v tests/distributed/test_distributed.py
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
================================================
FILE: .github/workflows/trufflehog.yml
================================================
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@v3.93.1
with:
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
extra_args: --results=verified,unknown --exclude-detectors=postgres
================================================
FILE: .github/workflows/upload_pr_documentation.yml
================================================
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: trl
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
================================================
FILE: .gitignore
================================================
*.bak
.gitattributes
.last_checked
.gitconfig
*.log
*~
~*
_tmp*
tmp*
tags
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.vscode
*.swp
# osx generated files
.DS_Store
.DS_Store?
.Trashes
ehthumbs.db
Thumbs.db
.idea
# pytest
.pytest_cache
# tools/trust-doc-nbs
docs_src/.last_checked
# symlinks to fastai
docs_src/fastai
tools/fastai
# link checker
checklink/cookies.txt
# .gitconfig is now autogenerated
.gitconfig
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/
# uv
uv.lock
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.3
hooks:
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi ]
# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0
# hooks:
# - id: codespell
# args:
# - --ignore-words-list=nd,reacher,thist,ths,magent,ba
# - --skip=docs/css/termynal.css,docs/js/termynal.js
================================================
FILE: AGENTS.md
================================================
# AGENTS.md
## Repository-specific guidance
### Main code vs experimental code
The repository is separated into **main code** and **experimental code**.
* **Main code** should remain stable, consistent, and well-tested.
* **Experimental code** may be less stable and may contain inconsistent patterns or limited testing.
Small non-invasive improvements that make experimental code more consistent with the main codebase are encouraged, but avoid large refactors.
### Paper implementations
If a PR implements a method, algorithm, or training approach from a research paper, it must also add a corresponding subsection to `paper_index.md`.
When reviewing such PRs, ensure that `paper_index.md` was updated.
### Code duplication and consistency
Trainers in this repository are **self-contained by design**. Shared logic (generation, reward computation, metric logging, weight syncing, etc.) is deliberately duplicated across trainers rather than abstracted into a shared base class.
This is intentional: each trainer must be readable, modifiable, and evolvable in isolation. The base class (`_BaseTrainer`) provides only minimal utilities (model card generation). Everything else — vLLM generation paths, `_get_per_token_logps_and_entropies`, `_calculate_rewards`, `_prepare_inputs`, metric logging — is copied in full.
**The tradeoff**: duplication is accepted, but **consistency is mandatory**. When the same logic appears in multiple trainers, the duplicated blocks must stay aligned:
- Same variable names (`self._last_loaded_step`, `self._metrics[mode]`, …)
- Same control flow structure (if/elif/else branches in the same order)
- Same comments (word-for-word when the logic is identical)
- Divergences only where the trainer's semantics require it (e.g., GRPO extracts logprobs from vLLM, RLOO discards them)
**Consistency over correctness**: this is a strong requirement. When duplicating code, reproduce it exactly — even if you believe the original has a bug. Do not silently fix the issue in your copy. Instead, keep your copy consistent with the source and report the problem so it can be fixed across all trainers in a dedicated PR. A correct-but-inconsistent codebase is harder to maintain than a consistently-wrong one that can be fixed in a single sweep.
**When modifying duplicated code**: if you change a pattern that exists in multiple trainers (e.g., the vLLM generation path in `_generate_single_turn`), apply the same change to all other trainers. A fix in GRPO often implies the same fix in RLOO, and vice versa. Not propagating a change is a bug.
**When reviewing**: if a PR touches duplicated logic, verify that all copies are updated consistently. A common mistake is fixing one trainer and forgetting the others.
### Simplicity
This codebase values **leanness and simplicity above all**. Prefer straightforward, inline code over abstractions, helpers, or utilities — even at the cost of some robustness or generality.
Concretely:
- Do not add layers of indirection (registries, factory patterns, plugin systems). A contributor should be able to read a trainer top to bottom and understand the full flow.
- Prefer a simple implementation that covers 90% of cases over a complex one that covers 100%. A function that handles the common path in 20 lines is better than a catch-all that handles every edge case in 80.
- Do not add defensive code, fallback paths, or configuration options "just in case". Only handle cases that actually exist today.
- Avoid `hasattr` and `getattr`. Their use is almost always a symptom of overly defensive programming or a disguised version check (e.g., "this attribute was added in version X"). Instead, either drop the conditional entirely or express the version check explicitly with a version comparison. There is nearly always a cleaner alternative.
- When in doubt, prefer less code. Every new function, parameter, or branch is maintenance burden. The best abstraction is often no abstraction.
## Documentation
### Docstrings
Docstrings must follow the repository format below. Do **not** convert docstrings to other styles (Google, NumPy, etc.).
Rules:
* Types appear in backticks inside parentheses: (`str`)
* Optional parameters are marked with `*optional*`
* Defaults are written as: `defaults to <value>`
* When the default is `None`, prefer ```(`str`, *optional*)``` instead of ```(`str` or `None`, *optional*, defaults to `None`)```
* Union types use `or`: `str` or `None`
* References to classes use the format: [`~transformers.PreTrainedModel`]
* Class docstrings may group parameters using headers such as: `> Parameters for X:`
Example:
````python
def method(self, param1: str, param2: int = 1, param3: float | None = None):
"""
Brief one-line description of what this does.
Args:
param1 (`str`):
Description of required param.
param2 (`int`, *optional*, defaults to `1`):
Description of optional param with default.
param3 (`float`, *optional*):
Description of optional param without explicit default.
Returns:
`dict` with keys:
- `key1` (`list[int]`):
Description of this key.
Examples:
```python
>>> my_func("hello")
```
"""
````
### Links to papers
When linking to papers, use `https://huggingface.co/papers/<id>` instead of `https://arxiv.org/abs/<id>` (same ID suffix system).
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
title: 'TRL: Transformers Reinforcement Learning'
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Leandro
family-names: von Werra
- given-names: Younes
family-names: Belkada
- given-names: Lewis
family-names: Tunstall
- given-names: Edward
family-names: Beeching
- given-names: Tristan
family-names: Thrush
- given-names: Nathan
family-names: Lambert
- given-names: Shengyi
family-names: Huang
- given-names: Kashif
family-names: Rasul
- given-names: Quentin
family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: >-
TRL (Transformers Reinforcement Learning) is an
open-source toolkit for aligning transformer models via
post-training. It provides practical, scalable
implementations of SFT, reward modeling, DPO, and GRPO
within the Hugging Face ecosystem.
keywords:
- transformers
- reinforcement learning
- preference optimization
- language model alignment
- post-training
license: Apache-2.0
version: '0.29'
date-released: '2020-03-27'
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
================================================
FILE: CONTRIBUTING.md
================================================
# How to contribute to TRL?
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to TRL:
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.
If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/%F0%9F%A7%92%20good%20second%20issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
> All contributions are equally valuable to the community. 🥰
Before you start contributing make sure you have installed all the dev tools:
```bash
pip install -e .[dev]
```
## Fixing outstanding issues
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
## Submitting a bug-related issue or feature request
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
### Did you find a bug?
The TRL library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained code snippet that allows us to reproduce the bug in less than 30 seconds.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, run the following command:
```bash
trl env
```
### Do you want a new feature?
If there is a new feature you'd like to see in TRL, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written, we're already 80% of the way there by the time you create it.
## Do you want to implement a new trainer?
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
* **Simplicity:** Does the new method achieve similar performance to prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al., 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al., 2024]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective to DPO but requires half the GPU VRAM.
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
* A short description of the method and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to model weights trained with the method if they are available.
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
## Do you want to add documentation?
We're always looking for improvements to the documentation that make it clearer and more accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content. We'll be happy to make the changes or help you contribute if you're interested!
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
git clone git@github.com:<your Github handle>/trl.git
cd trl
git remote add upstream https://github.com/huggingface/trl.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
git checkout main
git fetch upstream
git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
pip install -e .[dev]
```
(If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this:
```bash
pytest tests/<TEST_TO_RUN>.py
```
> The following commands require the `make` utility.
You can also run the full suite with the following command.
```bash
make test
```
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
To apply these checks and corrections in one step, use:
```bash
make precommit
```
This command runs the following:
* Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
* Runs additional scripts such as adding copyright information.
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:
```bash
git add modified_file.py
git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
git fetch upstream
git rebase upstream/main
```
Push the changes to your account using:
```bash
git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
We use `pytest` to run the tests. From the root of the
repository here's how to run tests with `pytest` for the library:
```bash
python -m pytest -sv ./tests
```
`make test` runs a similar command, with additional options to run the tests in parallel and to rerun flaky tests.
You can specify a smaller set of tests to test only the feature
you're working on.
### Default values guidelines
1. **Use defaults when appropriate**:
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
2. **Prioritize proven defaults**:
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
3. **Ensure safety and predictability**:
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
4. **Balance consistency and flexibility**:
Aim for consistent defaults across similar functions or methods. However, consistency should not take precedence over points 2 or 3.
5. **Opt-in for new features**:
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
### Writing documentation
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
To illustrate what good documentation looks like, here’s an example of a well-documented function:
````python
def replicate_str(string: str, n: int, sep: str = " ") -> str:
r"""
Replicate a string `n` times with a separator.
Args:
string (`str`):
String to replicate.
n (`int`):
Number of times to replicate the string.
sep (`str`, *optional*, defaults to `" "`):
Separator to use between each replication.
Returns:
`str`: The replicated string.
Examples:
```python
>>> replicate_str("hello", 3)
"hello hello hello"
>>> replicate_str("hello", 3, sep=", ")
"hello, hello, hello"
```
"""
return sep.join([string] * n)
````
* **Line Wrapping:** Wrap lines consistently at column 120 to improve readability.
* **Definite Articles:** Remove definite articles where possible to streamline language (e.g., write "String to replicate" instead of "The string to replicate").
* **Type Annotations:**
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
* **String Defaults:**
* Wrap default string values in double quotes:
```txt
defaults to `"foo"`
```
* **Dictionary Typing:**
* Replace generic `dict` type hints with more explicit ones such as `dict[str, Any]` to clarify expected key-value pairs.
* **Default Value Formatting:**
* Consistently surround default values with backticks for improved formatting:
```txt
defaults to `4`
```
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
```python
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
r"""
Calculates basic statistics for a given dataset.
Args:
> Data inputs
data (`list[float]`):
A list of numerical values to analyze.
> Configuration parameters
precision (`int`, *optional*, defaults to `2`):
Number of decimal places to round the results.
include_variance (`bool`, *optional*, defaults to `False`):
Whether to include the variance of the dataset in the results.
Returns:
`dict[str, float]`:
A dictionary containing calculated statistics such as mean, median, and optionally variance.
"""
...
```
### Deprecation and backward compatibility
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
```python
warnings.warn(
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
"Please use the `Trainer.bar` method instead.",
FutureWarning,
stacklevel=2,
)
```
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling removal around **5 minor releases** after the initial warning.
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
### Working with warnings
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
#### Definitions
* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
#### Choosing the right message
* **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
* **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
```python
logger.info("This is an informational message about a rare but correct operation.")
```
* **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
```python
def my_function(foo, bar, _warn=True):
if foo == bar:
if _warn:
logger.warning("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
# Do something
```
* **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
```python
def my_function(foo, bar):
if foo and bar:
logger.warning("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
# Do something
```
* **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
```python
def my_function(foo, bar):
if foo and bar:
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020-2026 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
include LICENSE
include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/skills/**/*.md
recursive-exclude * __pycache__
prune tests
================================================
FILE: MIGRATION.md
================================================
# Migrating from TRL v0 to v1
This guide covers the breaking changes introduced in TRL v1 and how to update your code. Most structural changes (trainers moved to experimental, removed model classes, etc.) already shipped in v0.29 — if you're already on v0.29, this migration is minimal.
## Changed defaults
| Config | Parameter | v0 default | v1 default | Action needed |
| --- | --- | --- | --- | --- |
| `GRPOConfig` | `vllm_mode` | `"server"` | `"colocate"` | If you use `use_vllm=True` without specifying `vllm_mode`, vLLM will now run in the same process instead of connecting to a separate server. Set `vllm_mode="server"` explicitly if you rely on server mode. |
| `RLOOConfig` | `vllm_mode` | `"server"` | `"colocate"` | Same as above. |
## Renamed options
| Config | Parameter | v0 value | v1 value | Action needed |
| --- | --- | --- | --- | --- |
| `SFTConfig` | `packing` | `"bfd-requeue"` | `"bfd_split"` | Replace `packing="bfd-requeue"` with `packing="bfd_split"`. The old value will still be accepted for a few versions but will be removed in a future release. |
## Migrating from an earlier version
Depending on which version you're migrating from, refer to the [release notes](https://github.com/huggingface/trl/releases) for v0.29 and earlier for version-specific changes.
================================================
FILE: Makefile
================================================
.PHONY: test precommit slow_tests test_experimental
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
# Fast test suite (excludes slow and low-priority tests), run in parallel.
# Failures are retried (up to 5 times) only when the error matches one of the
# known-flaky patterns passed to --only-rerun. Note: an empty alternative in
# that regex (e.g. "||") would match every error and rerun all failures.
test:
	pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' tests
# Add copyright headers, run all pre-commit hooks, then run the doc-builder style check.
precommit:
	python scripts/add_copyrights.py
	pre-commit run --all-files
	doc-builder style trl tests docs/source --max_len 119
# Slow tests only; on CI (IS_GITHUB_CI set) also write a report log.
slow_tests:
	pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
# Tests for the experimental namespace, run in parallel.
test_experimental:
	pytest -n auto -s -v tests/experimental
================================================
FILE: README.md
================================================
# TRL - Transformers Reinforcement Learning
<div style="text-align: center">
<picture>
<source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/TRL%20banner%20light.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
</picture>
</div>
<hr> <br>
<h3 align="center">
<p>A comprehensive library to post-train foundation models</p>
</h3>
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>
## 🎉 What's New
**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows.
Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](https://huggingface.co/docs/trl/openenv).
## Overview
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.
## Highlights
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
- **Efficient and scalable**:
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune models without needing to write code.
## Installation
### Python Package
Install the library using `pip`:
```bash
pip install trl
```
### From source
If you want to use the latest features before an official release, you can install TRL from source:
```bash
pip install git+https://github.com/huggingface/trl.git
```
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
git clone https://github.com/huggingface/trl.git
```
## Quick Start
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
```python
from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
```
### `GRPOTrainer`
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [DeepSeek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
```python
from datasets import load_dataset
from trl import GRPOTrainer
from trl.rewards import accuracy_reward
dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
reward_funcs=accuracy_reward,
train_dataset=dataset,
)
trainer.train()
```
> [!NOTE]
> For reasoning models, use the `reasoning_accuracy_reward()` function for better results.
### `DPOTrainer`
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
```python
from datasets import load_dataset
from trl import DPOTrainer
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = DPOTrainer(
model="Qwen/Qwen3-0.6B",
train_dataset=dataset,
)
trainer.train()
```
### `RewardTrainer`
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
```python
from trl import RewardTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
```
## Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
**SFT:**
```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
```
**DPO:**
```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details.
## Development
If you want to contribute to `trl` or customize it to your needs, read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and set up a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e ".[dev]"
```
## Experimental
A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice.
Example:
```python
from trl.experimental.new_trainer import NewTrainer
```
Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental_overview).
## Citation
```bibtex
@software{vonwerra2020trl,
title = {{TRL: Transformers Reinforcement Learning}},
author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouédec, Quentin},
license = {Apache-2.0},
url = {https://github.com/huggingface/trl},
year = {2020}
}
```
## License
This repository's source code is available under the [Apache-2.0 License](LICENSE).
================================================
FILE: RELEASE.md
================================================
# Making a release
> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
## Major/Minor Release
### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout main
git pull origin main
```
> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
### 2. Create a release branch from main
```bash
git checkout -b release-v{major}.{minor}
```
### 3. Change the version in the following files
- `.github/workflows/tests_latest.yml`:
```diff
- with: { ref: v{major}.{minor-1}-release }
+ with: { ref: v{major}.{minor}-release }
```
- `CITATION.cff`
```diff
- version: '{major}.{minor-1}'
+ version: '{major}.{minor}'
```
- `VERSION`
```diff
- {major}.{minor}.0.dev0
+ {major}.{minor}.0
```
### 4. Commit and push these changes
```shell
git add .github/workflows/tests_latest.yml CITATION.cff VERSION
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```
### 5. Create a pull request
Open a pull request from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
### 6. Once the pull request is approved, merge it into `main`
It will automatically publish the new version of the package on PyPI.
### 7. Add a tag in git to mark the release
```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```
### 8. Create a branch `v{major}.{minor}-release` for future patch releases
```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
### 9. Create a GitHub Release
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
5. Click **Publish Release**.
### 10. Bump to dev version
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and check it out.
```shell
git checkout -b bump-dev-version-{major}.{minor+1}
```
2. Change the version in file `VERSION`:
```diff
- {major}.{minor}.0
+ {major}.{minor+1}.0.dev0
```
3. Commit and push these changes
```shell
git add VERSION
git commit -m '⬆️ Bump dev version'
git push origin bump-dev-version-{major}.{minor+1}
```
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
5. Once the pull request is approved, merge it into `main`.
6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.
## Making a patch release
### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout v{major}.{minor}-release
git pull origin v{major}.{minor}-release
```
### 2. Cherry-pick the changes you want to include in the patch release
```bash
git cherry-pick <commit-hash-0>
git cherry-pick <commit-hash-1>
...
```
### 3. Change the version in the file `VERSION`
```diff
- {major}.{minor}.{patch-1}
+ {major}.{minor}.{patch}
```
### 4. Commit and push these changes
```shell
git add VERSION
git commit -m 'Release: {major}.{minor}.{patch}'
git push origin v{major}.{minor}-release
```
### 5. Wait for the CI to pass
The CI will automatically publish the new version of the package on PyPI.
### 6. Add a tag in git to mark the release
```shell
git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
git push origin v{major}.{minor}.{patch}
```
### 7. Create a GitHub Release
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.{patch}` tag you just created in step 6.
4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new.
5. Click **Publish Release**.
================================================
FILE: VERSION
================================================
1.0.0.dev0
================================================
FILE: docker/trl/Dockerfile
================================================
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
# Install git and remove the apt package lists afterwards to keep the layer small.
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip uv
# Install TRL from PyPI with the liger, peft and vlm extras, plus kernels and trackio.
# The requirement is quoted so the shell never treats "[liger,peft,vlm]" as a glob pattern.
RUN uv pip install --system "trl[liger,peft,vlm]" kernels trackio
================================================
FILE: docker/trl-dev/Dockerfile
================================================
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
# Install git and remove the apt package lists afterwards to keep the layer small.
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip uv
# Install TRL directly from the GitHub repository (default branch) with the liger, peft and vlm extras.
RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]"
# Additional packages installed explicitly; liger_kernel and peft overlap with the extras above
# (presumably kept as a safeguard in case extras are not resolved from the VCS URL — confirm).
RUN uv pip install --system kernels liger_kernel peft trackio
================================================
FILE: docs/source/_toctree.yml
================================================
- sections:
- local: index
title: TRL
- local: installation
title: Installation
- local: quickstart
title: Quickstart
title: Getting started
- sections:
- local: dataset_formats
title: Dataset Formats
- local: paper_index
title: Paper Index
title: Conceptual Guides
- sections: # Sorted alphabetically
- local: dpo_trainer
title: DPO
- local: grpo_trainer
title: GRPO
- local: reward_trainer
title: Reward
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
title: Trainers
- sections:
- local: clis
title: Command Line Interface (CLI)
- local: jobs_training
title: Training using Jobs
- local: customization
title: Customizing the Training
- local: reducing_memory_usage
title: Reducing Memory Usage
- local: speeding_up_training
title: Speeding Up Training
- local: distributing_training
title: Distributing Training
- local: use_model
title: Using Trained Models
title: How-to guides
- sections:
- local: deepspeed_integration
title: DeepSpeed
- local: kernels_hub
title: Kernels Hub
- local: liger_kernel_integration
title: Liger Kernel
- local: peft_integration
title: PEFT
- local: ptt_integration
title: Post Training Toolkit
- local: rapidfire_integration
title: RapidFire AI
- local: trackio_integration
title: Trackio
- local: unsloth_integration
title: Unsloth
- local: vllm_integration
title: vLLM
title: Integrations
- sections:
- local: example_overview
title: Example Overview
- local: community_tutorials
title: Community Tutorials
- local: lora_without_regret
title: LoRA Without Regret
title: Examples
- sections:
- sections:
- local: chat_template_utils
title: Chat Template Utilities
- local: data_utils
title: Data Utilities
- local: script_utils
title: Script Utilities
title: Utilities
- local: callbacks
title: Callbacks
- local: rewards
title: Reward Functions
title: API
- sections:
- local: experimental_overview
title: Experimental Overview
- local: openenv
title: OpenEnv Integration
- local: async_grpo_trainer # Sorted alphabetically
title: Asynchronous GRPO
- local: bema_for_reference_model
title: BEMA for Reference Model
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: gfpo
title: GFPO
- local: gkd_trainer
title: GKD
- local: gold_trainer
title: GOLD
- local: grpo_with_replay_buffer
title: GRPO With Replay Buffer
- local: gspo_token
title: GSPO-token
- local: judges
title: Judges
- local: kto_trainer
title: KTO
- local: merge_model_callback
title: MergeModelCallback
- local: minillm_trainer
title: MiniLLM
- local: nash_md_trainer
title: Nash-MD
- local: nemo_gym
title: NeMo Gym
- local: online_dpo_trainer
title: Online DPO
- local: orpo_trainer
title: ORPO
- local: papo_trainer
title: PAPO
- local: ppo_trainer
title: PPO
- local: prm_trainer
title: PRM
- local: winrate_callback
title: WinRateCallback
- local: xpo_trainer
title: XPO
title: Experimental
================================================
FILE: docs/source/async_grpo_trainer.md
================================================
# Asynchronous GRPO
> [!IMPORTANT]
> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).
>
> Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:
>
> ```bash
> pip install 'vllm>=0.17.1'
> pip install 'transformers>=5.2.0' --no-deps
> ```
## Overview
[`AsyncGRPOTrainer`] implements the same [GRPO](grpo_trainer) algorithm but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [`GRPOTrainer`] — for full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the [GRPO Trainer](grpo_trainer) documentation. Not all features from [`GRPOTrainer`] are available; refer to [`AsyncGRPOConfig`] for the supported parameters.
This trainer was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Amine Dirhoussi](https://huggingface.co/aminediroHF).
## How it differs from [`GRPOTrainer`]
In the standard [`GRPOTrainer`], generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in [vLLM colocate mode](grpo_trainer#speed-up-training-with-vllm), where generation runs on the same GPUs, one phase must finish before the other begins.
[`AsyncGRPOTrainer`] separates these two concerns:
- **Rollout worker** (background thread) — sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue.
- **Training loop** (main process) — pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights.
After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.
Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded.
The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.
## Quick start
```python
# train_async_grpo.py
from datasets import load_dataset
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward
dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
trainer = AsyncGRPOTrainer(
model="Qwen/Qwen3-4B",
reward_funcs=accuracy_reward,
train_dataset=dataset,
)
trainer.train()
```
The vLLM server and the trainer must run on **separate GPUs**. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:
```bash
# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
--max-model-len 4096 \
--logprobs-mode processed_logprobs \
--weight-transfer-config '{"backend":"nccl"}'
```
> [!TIP]
> Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config.
```bash
# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py
```
## Design philosophy
This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.
## AsyncGRPOConfig
[[autodoc]] trl.experimental.async_grpo.AsyncGRPOConfig
## AsyncGRPOTrainer
[[autodoc]] trl.experimental.async_grpo.AsyncGRPOTrainer
================================================
FILE: docs/source/bco_trainer.md
================================================
# BCO Trainer
[All models trained with BCO](https://huggingface.co/models?other=bco,trl)
TRL supports the Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example, have a look at the `examples/scripts/bco.py` script.
## Expected dataset type
The [`experimental.bco.BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`experimental.bco.BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Expected model format
The BCO trainer expects a model of type `AutoModelForCausalLM`, whereas PPO expects an `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `BCOTrainer`
For a detailed example, have a look at the `examples/scripts/bco.py` script. At a high level, we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model`, which we will use to calculate the implicit rewards of the preferred and rejected responses.
The `beta` argument is the hyperparameter of the implicit reward, and the dataset must follow the unpaired preference format described above (`prompt`, `completion`, and `label` entries). Note that the `model` and `ref_model` need to have the same architecture (i.e., decoder-only or encoder-decoder).
```python
from trl.experimental.bco import BCOConfig, BCOTrainer
training_args = BCOConfig(
beta=0.1,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
)
```
After this, you can call:
```python
bco_trainer.train()
```
## Underlying Distribution matching (UDM)
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
Choose an embedding model and tokenizer:
```python
embedding_model = AutoModel.from_pretrained(your_model_id)
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
# customize this function depending on your embedding model
def embed_prompt(input_ids, attention_mask, model):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
return outputs.last_hidden_state.mean(dim=1)
embedding_model = Accelerator().prepare_model(embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
```python
training_args = BCOConfig(
beta=0.1,
prompt_sample_size=512,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
embedding_func=embedding_func,
embedding_tokenizer=embedding_tokenizer,
)
bco_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MoEs are most efficient when the load is roughly evenly distributed across experts.
To ensure that MoEs are trained similarly during preference tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
## BCOTrainer
[[autodoc]] experimental.bco.BCOTrainer
- train
- save_model
- push_to_hub
## BCOConfig
[[autodoc]] experimental.bco.BCOConfig
================================================
FILE: docs/source/bema_for_reference_model.md
================================================
# BEMA for Reference Model
This feature implements the BEMA algorithm to update the reference model during DPO training.
## Usage
```python
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
from datasets import load_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
bema_callback = BEMACallback(update_ref_model=True)
trainer = DPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
train_dataset=dataset,
callbacks=[bema_callback],
)
trainer.train()
```
## DPOTrainer
[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
- train
- save_model
- push_to_hub
## BEMACallback
[[autodoc]] experimental.bema_for_ref_model.BEMACallback
================================================
FILE: docs/source/callbacks.md
================================================
# Callbacks
## RichProgressCallback
[[autodoc]] RichProgressCallback
## LogCompletionsCallback
[[autodoc]] LogCompletionsCallback
## BEMACallback
[[autodoc]] BEMACallback
## WeaveCallback
[[autodoc]] WeaveCallback
================================================
FILE: docs/source/chat_template_utils.md
================================================
# Chat template utilities
## clone_chat_template
[[autodoc]] clone_chat_template
## is_chat_template_prefix_preserving
[[autodoc]] chat_template_utils.is_chat_template_prefix_preserving
## get_training_chat_template
[[autodoc]] chat_template_utils.get_training_chat_template
================================================
FILE: docs/source/clis.md
================================================
# Command Line Interfaces (CLIs)
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
## Commands
Currently supported commands are:
### Training Commands
- `trl dpo`: fine-tune an LLM with DPO
- `trl grpo`: fine-tune an LLM with GRPO
- `trl kto`: fine-tune an LLM with KTO
- `trl reward`: train a reward model
- `trl rloo`: fine-tune an LLM with RLOO
- `trl sft`: fine-tune an LLM with SFT
### Other Commands
- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM
## Fine-Tuning with the TRL CLI
### Basic Usage
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
<hfoptions id="trainer">
<hfoption id="SFT">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb
```
</hfoption>
<hfoption id="DPO">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf
```
</hfoption>
<hfoption id="Reward">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized
```
</hfoption>
<hfoption id="GRPO">
```bash
trl grpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward
```
</hfoption>
<hfoption id="RLOO">
```bash
trl rloo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward
```
</hfoption>
<hfoption id="KTO">
```bash
trl kto \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/kto-mix-14k
```
</hfoption>
</hfoptions>
### Using Configuration Files
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
<hfoptions id="trainer">
<hfoption id="SFT">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
<hfoption id="GRPO">
```yaml
# grpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
```
Launch with:
```bash
trl grpo --config grpo_config.yaml
```
</hfoption>
<hfoption id="RLOO">
```yaml
# rloo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
```
Launch with:
```bash
trl rloo --config rloo_config.yaml
```
</hfoption>
<hfoption id="KTO">
```yaml
# kto_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/kto-mix-14k
```
Launch with:
```bash
trl kto --config kto_config.yaml
```
</hfoption>
</hfoptions>
### Scaling Up with Accelerate
TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs or machines, or to use advanced setups like DeepSpeed — all from the same CLI.
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information, see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
<hfoptions id="trainer">
<hfoption id="SFT">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--num_processes 4
```
or, with a config file:
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--num_processes 4
```
or, with a config file:
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_processes 4
```
or, with a config file:
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
num_processes: 4
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
<hfoption id="GRPO">
```bash
trl grpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward \
--num_processes 4
```
or, with a config file:
```yaml
# grpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
num_processes: 4
```
Launch with:
```bash
trl grpo --config grpo_config.yaml
```
</hfoption>
<hfoption id="RLOO">
```bash
trl rloo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward \
--num_processes 4
```
or, with a config file:
```yaml
# rloo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
num_processes: 4
```
Launch with:
```bash
trl rloo --config rloo_config.yaml
```
</hfoption>
<hfoption id="KTO">
```bash
trl kto \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/kto-mix-14k \
--num_processes 4
```
or, with a config file:
```yaml
# kto_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/kto-mix-14k
num_processes: 4
```
Launch with:
```bash
trl kto --config kto_config.yaml
```
</hfoption>
</hfoptions>
### Using `--accelerate_config` for Accelerate Configuration
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
- the name of a predefined config profile (built into TRL), or
- a path to a custom Accelerate YAML config file.
#### Predefined Config Profiles
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
| Name | Description |
| --- | --- |
| `fsdp1` | Fully Sharded Data Parallel (FSDP version 1) |
| `fsdp2` | Fully Sharded Data Parallel (FSDP version 2) |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
| `zero3` | DeepSpeed ZeRO Stage 3 |
| `multi_gpu` | Multi-GPU training |
| `single_gpu` | Single-GPU training |
To use one of these, just pass its name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_configs/`.
#### Example Usage
<hfoptions id="trainer">
<hfoption id="SFT">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
<hfoption id="GRPO">
```bash
trl grpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# grpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl grpo --config grpo_config.yaml
```
</hfoption>
<hfoption id="RLOO">
```bash
trl rloo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name HuggingFaceH4/Polaris-Dataset-53K \
--reward_funcs accuracy_reward \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# rloo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: HuggingFaceH4/Polaris-Dataset-53K
reward_funcs:
- accuracy_reward
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl rloo --config rloo_config.yaml
```
</hfoption>
<hfoption id="KTO">
```bash
trl kto \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/kto-mix-14k \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
or, with a config file:
```yaml
# kto_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/kto-mix-14k
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl kto --config kto_config.yaml
```
</hfoption>
</hfoptions>
### Using dataset mixtures
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
<hfoptions id="trainer">
<hfoption id="SFT">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: stanfordnlp/imdb
- path: roneneldan/TinyStories
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: BAAI/Infinity-Preference
- path: argilla/Capybara-Preferences
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: trl-lib/tldr-preference
- path: trl-lib/lm-human-preferences-sentiment
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
<hfoption id="GRPO">
```yaml
# grpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: HuggingFaceH4/Polaris-Dataset-53K
- path: trl-lib/DeepMath-103K
reward_funcs:
- accuracy_reward
```
Launch with:
```bash
trl grpo --config grpo_config.yaml
```
</hfoption>
<hfoption id="RLOO">
```yaml
# rloo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: HuggingFaceH4/Polaris-Dataset-53K
- path: trl-lib/DeepMath-103K
reward_funcs:
- accuracy_reward
```
Launch with:
```bash
trl rloo --config rloo_config.yaml
```
</hfoption>
<hfoption id="KTO">
```yaml
# kto_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: trl-lib/kto-mix-14k
- path: argilla/ultrafeedback-binarized-preferences-cleaned
```
Launch with:
```bash
trl kto --config kto_config.yaml
```
</hfoption>
</hfoptions>
To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes.
## Getting the System Information
You can get the system information by running the following command:
```bash
trl env
```
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```
This information is required when reporting an issue.
================================================
FILE: docs/source/community_tutorials.md
================================================
# Community Tutorials
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
## Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`experimental.orpo.ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
| Step-Level Reasoning | [`GRPOTrainer`] | Supervised Reinforcement Learning (SRL) for step-by-step reasoning with vLLM | [Deepak Swaminathan](https://huggingface.co/s23deepak) | [Link](https://github.com/s23deepak/Supervised-Reinforcement-Learning) | [](https://colab.research.google.com/github/s23deepak/Supervised-Reinforcement-Learning/blob/main/notebooks/srl_grpo_tutorial.ipynb) |
### Videos
| Task | Title | Author | Video |
| --- | --- | --- | --- |
| Instruction tuning | Fine-tuning open AI models using Hugging Face TRL | [Wietse Venema](https://huggingface.co/wietsevenema) | [<img src="https://img.youtube.com/vi/cnGyyM0vOes/0.jpg">](https://youtu.be/cnGyyM0vOes) |
| Instruction tuning | How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset | [Mayurji](https://huggingface.co/iammayur) | [<img src="https://img.youtube.com/vi/jKdXv3BiLu0/0.jpg">](https://youtu.be/jKdXv3BiLu0) |
<details>
<summary>⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand)</summary>
> [!WARNING]
> The tutorial uses two deprecated features:
>
> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
</details>
## Vision Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
| Object Detection Grounding | [`SFTTrainer`] | Fine tuning a VLM for Object Detection Grounding using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_object_detection_grounding) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_object_detection_grounding.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) |
## Speech Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial |
| --- | --- | --- | --- | --- |
| Text-to-Speech | [`GRPOTrainer`] | Post training a Speech Language Model with GRPO using TRL | [Steven Zheng](https://huggingface.co/Steveeeeeeen) | [Link](https://huggingface.co/blog/Steveeeeeeen/llasa-grpo) |
## Contributing
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
================================================
FILE: docs/source/cpo_trainer.md
================================================
# CPO Trainer
[![model badge](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl)
## Overview
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to teach the model to reject mistaken translations. The CPO objective is derived from the DPO objective.
## Quick start
This example demonstrates how to train a model using the CPO method. We use the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model, and preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), in its binarized form ([trl-lib/ultrafeedback_binarized](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized)). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_cpo.py
from datasets import load_dataset
from trl.experimental.cpo import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO")
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_cpo.py
```
## Expected dataset type
CPO requires a [preference dataset](dataset_formats#preference). The [`experimental.cpo.CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Example script
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py).
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/cpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--output_dir Qwen2-0.5B-CPO
```
## Logged metrics
While training and evaluating, we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPO variants
### Simple Preference Optimization (SimPO)
[Simple Preference Optimization](https://huggingface.co/papers/2405.14734) (SimPO) by [Yu Meng](https://huggingface.co/yumeng5), [Mengzhou Xia](https://huggingface.co/mengzhouxia), and [Danqi Chen](https://huggingface.co/cdq10131) proposes a simpler and more effective preference optimization algorithm than DPO without using a reference model. The key designs in SimPO are (1) using length-normalized log likelihood as the implicit reward, and (2) incorporating a target reward margin in the Bradley-Terry ranking objective. The official code can be found at [princeton-nlp/SimPO](https://github.com/princeton-nlp/SimPO).
The abstract from the paper is the following:
> Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model.
The SimPO loss is integrated in the [`experimental.cpo.CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`], and set `simpo_gamma` to a recommended value.
### CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`experimental.cpo.CPOConfig`].
### AlphaPO
The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`experimental.cpo.CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following:
> Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance.
To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`experimental.cpo.CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value.
## Loss functions
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`experimental.cpo.CPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| --- | --- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`experimental.cpo.CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`] and `simpo_gamma` to a recommended value. |
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`experimental.cpo.CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
### For Mixture of Experts Models: Enabling the auxiliary loss
MoEs are most efficient when the load is roughly equally distributed between experts.
To ensure that MoEs are trained similarly during preference tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g., [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## CPOTrainer
[[autodoc]] experimental.cpo.CPOTrainer
- train
- save_model
- push_to_hub
## CPOConfig
[[autodoc]] experimental.cpo.CPOConfig
================================================
FILE: docs/source/customization.md
================================================
# Training customization
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples of how you can apply and test different techniques.
> [!NOTE]
> Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL.
## Use different optimizers and schedulers
By default, the [`DPOTrainer`] creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to [`DPOTrainer`] as follows:
```python
from datasets import load_dataset
from torch import optim
from transformers import AutoModelForCausalLM
from trl import DPOTrainer
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
optimizer = optim.SGD(model.parameters(), lr=1e-6)
trainer = DPOTrainer(
model=model,
train_dataset=dataset,
optimizers=(optimizer, None),
)
trainer.train()
```
### Add a learning rate scheduler
You can also add learning rate schedulers by passing both optimizer and scheduler:
```python
from torch import optim
optimizer = optim.AdamW(model.parameters(), lr=1e-6)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler))
```
## Pass 8-bit reference models
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage 8-bit quantization from `transformers` (`load_in_8bit=True` in a `BitsAndBytesConfig`) for more memory-efficient fine-tuning.
Read more about 8-bit and 4-bit model loading in the `transformers` [bitsandbytes quantization guide](https://huggingface.co/docs/transformers/en/quantization/bitsandbytes).
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)
trainer = DPOTrainer(..., ref_model=ref_model)
```
## Add custom callbacks
You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training.
```python
from transformers import TrainerCallback
class CustomLoggingCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is not None:
print(f"Step {state.global_step}: {logs}")
trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()])
```
## Add custom evaluation metrics
You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks.
```python
def compute_metrics(eval_preds):
logits, labels = eval_preds
# Add your metric computation here
return {"custom_metric": 0.0}
training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100)
trainer = DPOTrainer(..., args=training_args, eval_dataset=eval_dataset, compute_metrics=compute_metrics)
```
## Use mixed precision training
Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config.
```python
# Use bfloat16 precision (recommended for modern GPUs)
training_args = DPOConfig(..., bf16=True)
```
Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs.
## Use gradient accumulation
When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating weights.
```python
# Simulate a batch size of 32 with per_device_train_batch_size=4 and gradient_accumulation_steps=8
training_args = DPOConfig(
...,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
)
```
================================================
FILE: docs/source/data_utils.md
================================================
# Data Utilities
## is_conversational
[[autodoc]] is_conversational
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt
[[autodoc]] extract_prompt
## unpair_preference_dataset
[[autodoc]] unpair_preference_dataset
================================================
FILE: docs/source/dataset_formats.md
================================================
# Dataset formats and types
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
## Overview of the dataset formats and types
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
<table>
<tr>
<th>Type \ Format</th>
<th>Standard</th>
<th>Conversational</th>
</tr>
<tr>
<td>Language modeling</td>
<td>
<pre><code>{"text": "The sky is blue."}</code></pre>
</td>
<td>
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-only</td>
<td>
<pre><code>{"prompt": "The sky is"}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-completion</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
<tr>
<td>Preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green."}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": "The sky is blue.",
"rejected": "The sky is green."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}</code></pre>
</td>
</tr>
<tr>
<td>Unpaired preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue.",
"label": True}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is green."}],
"label": False}</code></pre>
</td>
</tr>
<tr>
<td>Stepwise supervision</td>
<td>
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8.",
"The fractional part of 9.11 is 0.11.",
"0.11 is greater than 0.8.",
"Hence, 9.11 > 9.8."],
"labels": [True, True, False, False]}</code></pre>
</td>
<td></td>
</tr>
</table>
### Formats
#### Standard
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
```python
# Language modeling
language_modeling_example = {"text": "The sky is blue."}
# Preference
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Unpaired preference
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```
#### Conversational
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
```python
messages = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "I'd like to show off how chat templating works!"},
]
```
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
```python
# Prompt-completion
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
# Preference
preference_example = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}],
}
```
#### Tool Calling
Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool.
After the assistant initiates a tool call, the tool executes and returns its output. The assistant can then process this output and continue the conversation accordingly.
Here’s a simple example of a tool-calling interaction:
```python
messages = [
{"role": "user", "content": "Turn on the living room lights."},
{"role": "assistant", "tool_calls": [
{"type": "function", "function": {
"name": "control_light",
"arguments": {"room": "living room", "state": "on"}
}}]
},
{"role": "tool", "name": "control_light", "content": "The lights in the living room are now on."},
{"role": "assistant", "content": "Done!"}
]
```
When preparing datasets for Supervised Fine-Tuning (SFT) with tool calling, it is important that your dataset includes an additional column named `tools`. This column contains the list of available tools for the model, which is usually used by the chat template to construct the system prompt.
The tools must be specified in a codified JSON schema format. You can automatically generate this schema from Python function signatures using the [`~transformers.utils.get_json_schema`] utility:
```python
import json
from transformers.utils import get_json_schema
def control_light(room: str, state: str) -> str:
"""
Controls the lights in a room.
Args:
room: The name of the room.
state: The desired state of the light ("on" or "off").
Returns:
str: A message indicating the new state of the lights.
"""
return f"The lights in {room} are now {state}."
# Generate JSON schema
json_schema = get_json_schema(control_light)
```
The generated schema would look like:
```python
{"type": "function", "function": {"name": "control_light", "description": "Controls the lights in a room.", "parameters": {"type": "object", "properties": {"room": {"type": "string", "description": "The name of the room."}, "state": {"type": "string", "description": "The desired state of the light (\"on\" or \"off\")."}}, "required": ["room", "state"]}, "return": {"type": "string", "description": "str: A message indicating the new state of the lights."}}}
```
A complete dataset entry for SFT might look like:
```python
{"messages": messages, "tools": [json_schema]}
```
To get a `Dataset` you need to use the `Json()` type for tool arguments since they are arbitrary JSON objects, and not dictionaries with fixed fields and types:
```python
from datasets import Dataset
data = [
{"messages": messages1, "tools": [json_schema1]},
{"messages": messages2, "tools": [json_schema2]},
]
# auto-apply the Json() type
dataset = Dataset.from_list(data, on_mixed_types="use_json")
# or specify the features manually
from datasets import Features, Json, List, Value
features = Features(
{
"messages": List({"role": Value("string"), "content": Value("string"), "tool_calls": List(Json())}),
"tools": List(Json()),
}
)
dataset = Dataset.from_list(data, features=features)
```
On older versions of `datasets` (<4.7.0) that don't have the `Json()` type, you should store `tools` as a JSON `str` (with `json.dumps([...])`):
```python
dataset = Dataset.from_list(
[{"messages": messages1, "tools": json.dumps([json_schema1])},
{"messages": messages2, "tools": json.dumps([json_schema2])}]
)
```
For more detailed information on tool calling, refer to the [Tool Calling section in the `transformers` documentation](https://huggingface.co/docs/transformers/chat_extras#tools-and-rag) and the blog post [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
### Harmony
The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include:
- **Developer role** – Provides high level instructions (similar to a system prompt) and lists available tools.
- **Channels** – Separate types of assistant output into distinct streams:
- `analysis` – for internal reasoning, from the key `"thinking"`
- `final` – for the user-facing answer, from the key `"content"`
- `commentary` – for tool calls or meta notes
- **Reasoning effort** – Signals how much thinking the model should show (e.g., `"low"`, `"medium"`, `"high"`).
- **Model identity** – Explicitly defines the assistant’s persona.
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
messages = [
{"role": "developer", "content": "Use a friendly tone."},
{"role": "user", "content": "What is the meaning of life?"},
{"role": "assistant", "thinking": "Deep reflection...", "content": "The final answer is..."},
]
print(
tokenizer.apply_chat_template(
messages,
tokenize=False,
reasoning_effort="low",
model_identity="You are HuggingGPT, a large language model trained by Hugging Face.",
)
)
```
This produces:
```txt
<|start|>system<|message|>You are HuggingGPT, a large language model trained by Hugging Face.
Knowledge cutoff: 2024-06
Current date: 2025-08-03
Reasoning: low
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
Use a friendly tone.<|end|><|start|>user<|message|>What is the meaning of life?<|end|><|start|>assistant<|channel|>analysis<|message|>Deep reflection...<|end|><|start|>assistant<|channel|>final<|message|>The final answer is...<|return|>
```
For full details on message structure, supported fields, and advanced usage, see the [Harmony documentation](https://cookbook.openai.com/articles/openai-harmony).
### Types
#### Language modeling
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
```python
# Standard format
language_modeling_example = {"text": "The sky is blue."}
# Conversational format
language_modeling_example = {"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}
]}
```
#### Prompt-only
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating a completion based on this prompt, where the model learns to continue or complete the given input.
```python
# Standard format
prompt_only_example = {"prompt": "The sky is"}
# Conversational format
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
> [!TIP]
> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
>
> ```python
> from transformers import AutoTokenizer
> from trl import apply_chat_template
>
> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
>
> # Example for prompt-only type
> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
> apply_chat_template(prompt_only_example, tokenizer)
> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
>
> # Example for language modeling type
> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
> apply_chat_template(lm_example, tokenizer)
> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
> ```
>
> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
#### Prompt-completion
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
```python
# Standard format
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
# Conversational format
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
```
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
#### Preference
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
```python
# Standard format
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
## Implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
# Conversational format
## Explicit prompt (recommended)
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
## Implicit prompt
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
```
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
#### Unpaired preference
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
```python
# Standard format
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
# Conversational format
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
"label": True}
```
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
#### Stepwise supervision
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
```python
stepwise_example = {
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
"labels": [True, False]
}
```
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
## Which dataset type to use?
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
| Trainer | Expected dataset type |
| --- | --- |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`experimental.kto.KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.online_dpo.OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
| [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) |
## Using any dataset with TRL: preprocessing and conversion
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
### Example: UltraFeedback dataset
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
<iframe
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
```sh
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
```
Once converted, the dataset will look like this:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Now, you can use this dataset with TRL!
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
## Utilities for converting dataset types
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
### From prompt-completion to language modeling dataset
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"completion": [" blue.", " in the sky."],
})
def concat_prompt_completion(example):
return {"text": example["prompt"] + example["completion"]}
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From prompt-completion to prompt-only dataset
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"completion": [" blue.", " in the sky."],
})
dataset = dataset.remove_columns("completion")
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From preference with implicit prompt to language modeling dataset
To convert a preference dataset with implicit prompt into a language modeling dataset, remove the `"rejected"` column and rename the `"chosen"` column to `"text"`.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"chosen": ["The sky is blue.", "The sun is in the sky."],
"rejected": ["The sky is green.", "The sun is in the sea."],
})
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From preference with implicit prompt to prompt-completion dataset
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the `"rejected"` column, and rename the `"chosen"` column to `"completion"`.
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
```
### From preference with implicit prompt to prompt-only dataset
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the `"chosen"` and `"rejected"` columns.
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
```
### From implicit to explicit prompt preference dataset
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt)
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
```
### From preference with implicit prompt to unpaired preference dataset
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
```python
from datasets import Dataset
from trl import extract_prompt, unpair_preference_dataset

conversations = {
    "chosen": [
        [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
        [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
    ],
    "rejected": [
        [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
        [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
    ],
}
dataset = Dataset.from_dict(conversations)

# Make the prompt explicit first, then unpair into prompt/completion/label rows
dataset = unpair_preference_dataset(dataset.map(extract_prompt))
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
'label': True}
```
> [!WARNING]
> Keep in mind that `"chosen"` and `"rejected"` only express a relative preference: in a preference dataset, both completions can be good, or both can be bad.
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
> You can verify this by checking the absolute rating of each completion, e.g. from a reward model.
### From preference to language modeling dataset
To convert a preference dataset into a language modeling dataset, remove the rejected completion, and concatenate the prompt and the chosen completion into the `"text"` column.
```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is", "The sun is"],
    "chosen": [" blue.", " in the sky."],
    "rejected": [" green.", " in the sea."],
})


def concat_prompt_chosen(example):
    # Keep only the preferred continuation, appended to its prompt
    return {"text": example["prompt"] + example["chosen"]}


dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From preference to prompt-completion dataset
To convert a preference dataset into a prompt-completion dataset, remove the `"rejected"` column and rename the `"chosen"` column to `"completion"`.
```python
from datasets import Dataset

data = {
    "prompt": ["The sky is", "The sun is"],
    "chosen": [" blue.", " in the sky."],
    "rejected": [" green.", " in the sea."],
}
dataset = Dataset.from_dict(data)

# Drop the rejected completions, then expose the chosen ones as "completion"
dataset = dataset.remove_columns("rejected")
dataset = dataset.rename_column("chosen", "completion")
```
```python
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.'}
```
### From preference to prompt-only dataset
To convert a preference dataset into a prompt-only dataset, remove the `"chosen"` and `"rejected"` columns.
```python
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "prompt": ["The sky is", "The sun is"],
        "chosen": [" blue.", " in the sky."],
        "rejected": [" green.", " in the sea."],
    }
)
# Only the prompts are kept; both completion columns are dropped
completion_columns = ["chosen", "rejected"]
dataset = dataset.remove_columns(completion_columns)
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From explicit to implicit prompt preference dataset
To convert a preference dataset with an explicit prompt into a preference dataset with an implicit prompt, prepend the prompt to both the chosen and rejected completions, and remove the `"prompt"` column.
```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": [
        [{"role": "user", "content": "What color is the sky?"}],
        [{"role": "user", "content": "Where is the sun?"}],
    ],
    "chosen": [
        [{"role": "assistant", "content": "It is blue."}],
        [{"role": "assistant", "content": "In the sky."}],
    ],
    "rejected": [
        [{"role": "assistant", "content": "It is green."}],
        [{"role": "assistant", "content": "In the sea."}],
    ],
})


def concat_prompt_to_completions(example):
    # Prepend the prompt messages to each completion so the prompt becomes implicit
    return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}


dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
```
```python
>>> dataset[0]
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
```
### From preference to unpaired preference dataset
To convert a preference dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
```python
from datasets import Dataset
from trl import unpair_preference_dataset

prompts = [
    [{"role": "user", "content": "What color is the sky?"}],
    [{"role": "user", "content": "Where is the sun?"}],
]
chosen = [
    [{"role": "assistant", "content": "It is blue."}],
    [{"role": "assistant", "content": "In the sky."}],
]
rejected = [
    [{"role": "assistant", "content": "It is green."}],
    [{"role": "assistant", "content": "In the sea."}],
]
paired = Dataset.from_dict({"prompt": prompts, "chosen": chosen, "rejected": rejected})

# Turn the (chosen, rejected) pairs into labeled prompt/completion rows
dataset = unpair_preference_dataset(paired)
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
'label': True}
```
> [!WARNING]
> Keep in mind that `"chosen"` and `"rejected"` only express a relative preference: in a preference dataset, both completions can be good, or both can be bad.
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
> You can verify this by checking the absolute rating of each completion, e.g. from a reward model.
### From unpaired preference to language modeling dataset
To convert an unpaired preference dataset into a language modeling dataset, keep only the examples with a good label, concatenate each prompt with its completion into the `"text"` column, and remove the prompt, completion, and label columns.
```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
    "completion": [" blue.", " in the sky.", " green.", " in the sea."],
    "label": [True, True, False, False],
})


def concatenate_prompt_completion(example):
    # Join the prompt and its completion into a single text sample
    return {"text": example["prompt"] + example["completion"]}


# Keep only good completions, build the text, then drop the source columns
dataset = dataset.filter(lambda x: x["label"])
dataset = dataset.map(concatenate_prompt_completion)
dataset = dataset.remove_columns(["prompt", "completion", "label"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From unpaired preference to prompt-completion dataset
To convert an unpaired preference dataset into a prompt-completion dataset, keep only the examples with a good label, then remove the `"label"` column.
```python
from datasets import Dataset

data = {
    "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
    "completion": [" blue.", " in the sky.", " green.", " in the sea."],
    "label": [True, True, False, False],
}
dataset = Dataset.from_dict(data)

# Keep the good completions only; the label is then redundant
good_only = dataset.filter(lambda example: example["label"])
dataset = good_only.remove_columns(["label"])
```
```python
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.'}
```
### From unpaired preference to prompt-only dataset
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
```python
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
        "completion": [" blue.", " in the sky.", " green.", " in the sea."],
        "label": [True, True, False, False],
    }
)
# Only the prompts are kept
unused_columns = ["completion", "label"]
dataset = dataset.remove_columns(unused_columns)
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From stepwise supervision to language modeling dataset
To convert a stepwise supervision dataset into a language modeling dataset, keep only the examples whose steps are all labeled as good, then concatenate each prompt with its completion steps into the `"text"` column.
```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})


def concatenate_prompt_completions(example):
    # Join all completion steps, then prefix them with the prompt
    completion = "".join(example["completions"])
    return {"text": example["prompt"] + completion}


# Only examples whose every step is labeled True pass the filter
# (with the labels above, only the second example, "Water", remains).
dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
```
```python
>>> dataset[0]
{'text': 'Water forms a less dense structure in ice, which causes it to expand when it freezes.'}
```
### From stepwise supervision to prompt-completion dataset
To convert a stepwise supervision dataset into a prompt-completion dataset, keep only the examples whose steps are all labeled as good, join the completion steps, and remove the labels.
```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})


def join_completions(example):
    # Merge the individual steps into a single completion string
    completion = "".join(example["completions"])
    return {"completion": completion}


# Only examples whose every step is labeled True pass the filter
# (with the labels above, only the second example, "Water", remains).
dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Water', 'completion': ' forms a less dense structure in ice, which causes it to expand when it freezes.'}
```
### From stepwise supervision to prompt-only dataset
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
```python
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "prompt": ["Blue light", "Water"],
        "completions": [
            [" scatters more in the atmosphere,", " so the sky is green."],
            [" forms a less dense structure in ice,", " which causes it to expand when it freezes."],
        ],
        "labels": [[True, False], [True, True]],
    }
)
# Keep the prompts only: the step completions and their labels are dropped
step_columns = ["completions", "labels"]
dataset = dataset.remove_columns(step_columns)
```
```python
>>> dataset[0]
{'prompt': 'Blue light'}
```
### From stepwise supervision to unpaired preference dataset
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completion steps and merge the step labels into a single label.
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def merge_completions_and_labels(example):
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "l
gitextract_r678upi2/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug-report.yml
│ │ ├── feature-request.yml
│ │ └── new-trainer-addition.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ ├── codeql/
│ │ └── custom-queries.qls
│ └── workflows/
│ ├── build_documentation.yml
│ ├── build_pr_documentation.yml
│ ├── clear_cache.yml
│ ├── codeQL.yml
│ ├── docker-build.yml
│ ├── issue_auto_labeller.yml
│ ├── pr_style_bot.yml
│ ├── publish.yml
│ ├── slow-tests.yml
│ ├── tests-experimental.yml
│ ├── tests.yml
│ ├── tests_latest.yml
│ ├── tests_transformers_branch.yml
│ ├── trufflehog.yml
│ └── upload_pr_documentation.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── MIGRATION.md
├── Makefile
├── README.md
├── RELEASE.md
├── VERSION
├── docker/
│ ├── trl/
│ │ └── Dockerfile
│ └── trl-dev/
│ └── Dockerfile
├── docs/
│ └── source/
│ ├── _toctree.yml
│ ├── async_grpo_trainer.md
│ ├── bco_trainer.md
│ ├── bema_for_reference_model.md
│ ├── callbacks.md
│ ├── chat_template_utils.md
│ ├── clis.md
│ ├── community_tutorials.md
│ ├── cpo_trainer.md
│ ├── customization.md
│ ├── data_utils.md
│ ├── dataset_formats.md
│ ├── deepspeed_integration.md
│ ├── distributing_training.md
│ ├── dpo_trainer.md
│ ├── example_overview.md
│ ├── experimental_overview.md
│ ├── gfpo.md
│ ├── gkd_trainer.md
│ ├── gold_trainer.md
│ ├── grpo_trainer.md
│ ├── grpo_with_replay_buffer.md
│ ├── gspo_token.md
│ ├── index.md
│ ├── installation.md
│ ├── jobs_training.md
│ ├── judges.md
│ ├── kernels_hub.md
│ ├── kto_trainer.md
│ ├── liger_kernel_integration.md
│ ├── lora_without_regret.md
│ ├── merge_model_callback.md
│ ├── minillm_trainer.md
│ ├── nash_md_trainer.md
│ ├── nemo_gym.md
│ ├── online_dpo_trainer.md
│ ├── openenv.md
│ ├── orpo_trainer.md
│ ├── paper_index.md
│ ├── papo_trainer.md
│ ├── peft_integration.md
│ ├── ppo_trainer.md
│ ├── prm_trainer.md
│ ├── ptt_integration.md
│ ├── quickstart.md
│ ├── rapidfire_integration.md
│ ├── reducing_memory_usage.md
│ ├── reward_trainer.md
│ ├── rewards.md
│ ├── rloo_trainer.md
│ ├── script_utils.md
│ ├── sft_trainer.md
│ ├── speeding_up_training.md
│ ├── trackio_integration.md
│ ├── unsloth_integration.md
│ ├── use_model.md
│ ├── vllm_integration.md
│ ├── winrate_callback.md
│ └── xpo_trainer.md
├── examples/
│ ├── README.md
│ ├── accelerate_configs/
│ │ ├── alst_ulysses_4gpu.yaml
│ │ ├── context_parallel_2gpu.yaml
│ │ ├── deepspeed_zero1.yaml
│ │ ├── deepspeed_zero2.yaml
│ │ ├── deepspeed_zero3.yaml
│ │ ├── fsdp1.yaml
│ │ ├── fsdp2.yaml
│ │ ├── multi_gpu.yaml
│ │ └── single_gpu.yaml
│ ├── cli_configs/
│ │ └── example_config.yaml
│ ├── datasets/
│ │ ├── deepmath_103k.py
│ │ ├── hh-rlhf-helpful-base.py
│ │ ├── llava_instruct_mix.py
│ │ ├── lm-human-preferences-descriptiveness.py
│ │ ├── lm-human-preferences-sentiment.py
│ │ ├── math_shepherd.py
│ │ ├── prm800k.py
│ │ ├── rlaif-v.py
│ │ ├── tldr.py
│ │ ├── tldr_preference.py
│ │ ├── ultrafeedback-prompt.py
│ │ └── ultrafeedback.py
│ ├── notebooks/
│ │ ├── README.md
│ │ ├── grpo_agent.ipynb
│ │ ├── grpo_functiongemma_browsergym_openenv.ipynb
│ │ ├── grpo_ministral3_vl.ipynb
│ │ ├── grpo_qwen3_vl.ipynb
│ │ ├── grpo_rnj_1_instruct.ipynb
│ │ ├── grpo_trl_lora_qlora.ipynb
│ │ ├── openenv_sudoku_grpo.ipynb
│ │ ├── openenv_wordle_grpo.ipynb
│ │ ├── sft_ministral3_vl.ipynb
│ │ ├── sft_nemotron_3.ipynb
│ │ ├── sft_qwen_vl.ipynb
│ │ ├── sft_tool_calling.ipynb
│ │ └── sft_trl_lora_qlora.ipynb
│ └── scripts/
│ ├── async_grpo.py
│ ├── bco.py
│ ├── cpo.py
│ ├── dpo.py
│ ├── dpo_vlm.py
│ ├── evals/
│ │ └── judge_tldr.py
│ ├── gkd.py
│ ├── grpo_2048.py
│ ├── grpo_agent.py
│ ├── grpo_vlm.py
│ ├── gspo.py
│ ├── gspo_vlm.py
│ ├── kto.py
│ ├── mpo_vlm.py
│ ├── nash_md.py
│ ├── nemo_gym/
│ │ ├── README.md
│ │ ├── config.yaml
│ │ ├── deepspeed_zero3.yaml
│ │ ├── submit.sh
│ │ └── train_multi_environment.py
│ ├── online_dpo.py
│ ├── online_dpo_vlm.py
│ ├── openenv/
│ │ ├── browsergym.py
│ │ ├── browsergym_llm.py
│ │ ├── carla.py
│ │ ├── catch.py
│ │ ├── echo.py
│ │ ├── sudoku.py
│ │ ├── sudoku_prompt.txt
│ │ └── wordle.py
│ ├── orpo.py
│ ├── ppo/
│ │ ├── ppo.py
│ │ └── ppo_tldr.py
│ ├── prm.py
│ ├── reward_modeling.py
│ ├── rloo.py
│ ├── rloo_vlm.py
│ ├── sft.py
│ ├── sft_gemma3.py
│ ├── sft_gpt_oss.py
│ ├── sft_nemotron_3.py
│ ├── sft_tiny_aya_tool_calling.py
│ ├── sft_video_llm.py
│ ├── sft_vlm.py
│ ├── sft_vlm_gemma3.py
│ ├── tiny_aya_chat_template.jinja
│ └── xpo.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── add_copyrights.py
│ ├── generate_harmony_dataset.py
│ ├── generate_tiny_models.py
│ ├── generate_toolcall_dataset.py
│ ├── generate_zen_dataset.py
│ ├── generate_zen_image_dataset.py
│ ├── generate_zen_multi_image_dataset.py
│ └── log_reports.py
├── tests/
│ ├── __init__.py
│ ├── conftest.py
│ ├── data/
│ │ └── template.jinja
│ ├── distributed/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ └── accelerate_configs/
│ │ │ ├── ddp.yaml
│ │ │ ├── fsdp2.yaml
│ │ │ ├── zero2.yaml
│ │ │ └── zero3.yaml
│ │ └── test_distributed.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ ├── test_async_grpo_trainer.py
│ │ ├── test_bco_trainer.py
│ │ ├── test_cpo_trainer.py
│ │ ├── test_dppo_trainer.py
│ │ ├── test_gkd_trainer.py
│ │ ├── test_gold_trainer.py
│ │ ├── test_grpo_with_replay_buffer_trainer.py
│ │ ├── test_gspo_token_trainer.py
│ │ ├── test_judges.py
│ │ ├── test_kto_trainer.py
│ │ ├── test_merge_model_callback.py
│ │ ├── test_minillm_trainer.py
│ │ ├── test_modeling_value_head.py
│ │ ├── test_nash_md_trainer.py
│ │ ├── test_online_dpo_trainer.py
│ │ ├── test_orpo_trainer.py
│ │ ├── test_ppo_trainer.py
│ │ ├── test_prm_trainer.py
│ │ ├── test_utils.py
│ │ ├── test_winrate_callback.py
│ │ ├── test_xpo_trainer.py
│ │ └── testing_utils.py
│ ├── test_activation_offloading.py
│ ├── test_callbacks.py
│ ├── test_chat_template_utils.py
│ ├── test_cli.py
│ ├── test_cli_utils.py
│ ├── test_data_utils.py
│ ├── test_dpo_trainer.py
│ ├── test_grpo_trainer.py
│ ├── test_model_utils.py
│ ├── test_reward_trainer.py
│ ├── test_rewards.py
│ ├── test_rich_progress_callback.py
│ ├── test_rloo_trainer.py
│ ├── test_sft_trainer.py
│ ├── test_skills.py
│ ├── test_skills_cli.py
│ ├── test_utils.py
│ ├── test_vllm_client_server.py
│ ├── testing_constants.py
│ └── testing_utils.py
└── trl/
├── __init__.py
├── _compat.py
├── _lazy_module.py
├── accelerate_configs/
│ ├── fsdp1.yaml
│ ├── fsdp2.yaml
│ ├── multi_gpu.yaml
│ ├── single_gpu.yaml
│ ├── zero1.yaml
│ ├── zero2.yaml
│ └── zero3.yaml
├── chat_template_utils.py
├── cli/
│ ├── __init__.py
│ ├── accelerate_config.py
│ ├── accelerate_launcher.py
│ ├── commands/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── env.py
│ │ ├── skills.py
│ │ ├── training.py
│ │ └── vllm_serve.py
│ └── main.py
├── data_utils.py
├── experimental/
│ ├── __init__.py
│ ├── async_grpo/
│ │ ├── __init__.py
│ │ ├── async_grpo_config.py
│ │ ├── async_grpo_trainer.py
│ │ └── async_rollout_worker.py
│ ├── bco/
│ │ ├── __init__.py
│ │ ├── bco_config.py
│ │ └── bco_trainer.py
│ ├── bema_for_ref_model/
│ │ ├── __init__.py
│ │ ├── callback.py
│ │ └── dpo_trainer.py
│ ├── cpo/
│ │ ├── __init__.py
│ │ ├── cpo_config.py
│ │ └── cpo_trainer.py
│ ├── dppo/
│ │ ├── __init__.py
│ │ ├── dppo_config.py
│ │ └── dppo_trainer.py
│ ├── gfpo/
│ │ ├── __init__.py
│ │ ├── gfpo_config.py
│ │ └── gfpo_trainer.py
│ ├── gkd/
│ │ ├── __init__.py
│ │ ├── gkd_config.py
│ │ └── gkd_trainer.py
│ ├── gold/
│ │ ├── __init__.py
│ │ ├── gold.py
│ │ ├── gold_config.py
│ │ └── gold_trainer.py
│ ├── grpo_with_replay_buffer/
│ │ ├── __init__.py
│ │ ├── grpo_with_replay_buffer_config.py
│ │ └── grpo_with_replay_buffer_trainer.py
│ ├── gspo_token/
│ │ ├── __init__.py
│ │ └── grpo_trainer.py
│ ├── judges/
│ │ ├── __init__.py
│ │ └── judges.py
│ ├── kto/
│ │ ├── __init__.py
│ │ ├── kto_config.py
│ │ └── kto_trainer.py
│ ├── merge_model_callback.py
│ ├── minillm/
│ │ ├── __init__.py
│ │ ├── minillm_config.py
│ │ └── minillm_trainer.py
│ ├── nash_md/
│ │ ├── __init__.py
│ │ ├── nash_md_config.py
│ │ └── nash_md_trainer.py
│ ├── online_dpo/
│ │ ├── __init__.py
│ │ ├── online_dpo_config.py
│ │ └── online_dpo_trainer.py
│ ├── openenv/
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── orpo/
│ │ ├── __init__.py
│ │ ├── orpo_config.py
│ │ └── orpo_trainer.py
│ ├── papo/
│ │ ├── __init__.py
│ │ ├── papo_config.py
│ │ └── papo_trainer.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── modeling_value_head.py
│ │ ├── ppo_config.py
│ │ └── ppo_trainer.py
│ ├── prm/
│ │ ├── __init__.py
│ │ ├── prm_config.py
│ │ └── prm_trainer.py
│ ├── utils.py
│ ├── winrate_callback.py
│ └── xpo/
│ ├── __init__.py
│ ├── xpo_config.py
│ └── xpo_trainer.py
├── extras/
│ ├── __init__.py
│ ├── dataset_formatting.py
│ └── profiling.py
├── generation/
│ ├── __init__.py
│ ├── vllm_client.py
│ └── vllm_generation.py
├── import_utils.py
├── models/
│ ├── __init__.py
│ ├── activation_offloading.py
│ └── utils.py
├── py.typed
├── rewards/
│ ├── __init__.py
│ ├── accuracy_rewards.py
│ ├── format_rewards.py
│ └── other_rewards.py
├── scripts/
│ ├── __init__.py
│ ├── _hf_argparser.py
│ ├── dpo.py
│ ├── env.py
│ ├── grpo.py
│ ├── kto.py
│ ├── reward.py
│ ├── rloo.py
│ ├── sft.py
│ ├── utils.py
│ └── vllm_serve.py
├── skills/
│ ├── __init__.py
│ ├── cli.py
│ ├── skills.py
│ └── trl-training/
│ └── SKILL.md
├── templates/
│ ├── completions_dataset_card.md
│ ├── lm_model_card.md
│ └── rm_model_card.md
└── trainer/
├── __init__.py
├── base_config.py
├── base_trainer.py
├── callbacks.py
├── dpo_config.py
├── dpo_trainer.py
├── grpo_config.py
├── grpo_trainer.py
├── kto_config.py
├── kto_trainer.py
├── model_config.py
├── reward_config.py
├── reward_trainer.py
├── rloo_config.py
├── rloo_trainer.py
├── sft_config.py
├── sft_trainer.py
└── utils.py
SYMBOL INDEX (1986 symbols across 186 files)
FILE: examples/datasets/deepmath_103k.py
class ScriptArguments (line 23) | class ScriptArguments:
function process_example (line 50) | def process_example(example):
FILE: examples/datasets/hh-rlhf-helpful-base.py
class ScriptArguments (line 24) | class ScriptArguments:
function common_start (line 49) | def common_start(str1: str, str2: str) -> str:
function extract_dialogue (line 61) | def extract_dialogue(example: str) -> list[dict[str, str]]:
FILE: examples/datasets/llava_instruct_mix.py
class ScriptArguments (line 24) | class ScriptArguments:
function process_example (line 51) | def process_example(example):
function filter_long_examples (line 61) | def filter_long_examples(example):
function split_prompt_completion (line 66) | def split_prompt_completion(example):
FILE: examples/datasets/lm-human-preferences-descriptiveness.py
class ScriptArguments (line 23) | class ScriptArguments:
function samples_not_all_same (line 51) | def samples_not_all_same(example):
function to_prompt_completion (line 55) | def to_prompt_completion(example, tokenizer):
FILE: examples/datasets/lm-human-preferences-sentiment.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_prompt_completion (line 50) | def to_prompt_completion(example, tokenizer):
FILE: examples/datasets/math_shepherd.py
class ScriptArguments (line 25) | class ScriptArguments:
function process_example (line 52) | def process_example(example):
FILE: examples/datasets/prm800k.py
class ScriptArguments (line 23) | class ScriptArguments:
function process_example (line 50) | def process_example(example):
function process_batch (line 89) | def process_batch(examples):
FILE: examples/datasets/rlaif-v.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_conversational (line 50) | def to_conversational(example):
FILE: examples/datasets/tldr.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_prompt_completion (line 50) | def to_prompt_completion(example):
FILE: examples/datasets/tldr_preference.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_preference (line 50) | def to_preference(example):
FILE: examples/datasets/ultrafeedback-prompt.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_unpaired_preference (line 50) | def to_unpaired_preference(example):
function drop_long_prompt (line 55) | def drop_long_prompt(example):
FILE: examples/datasets/ultrafeedback.py
class ScriptArguments (line 23) | class ScriptArguments:
function to_unpaired_preference (line 87) | def to_unpaired_preference(example, model_name, aspect):
FILE: examples/scripts/async_grpo.py
function format_sample (line 39) | def format_sample(sample):
function main (line 43) | def main() -> None:
FILE: examples/scripts/bco.py
function embed_prompt (line 90) | def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.Long...
FILE: examples/scripts/evals/judge_tldr.py
class ScriptArguments (line 54) | class ScriptArguments:
FILE: examples/scripts/grpo_2048.py
class Game2048Env (line 32) | class Game2048Env:
method reset (line 33) | def reset(self, **kwargs) -> str:
method move (line 41) | def move(self, direction: str) -> str:
method _spawn (line 60) | def _spawn(self) -> None:
method _merge_line (line 68) | def _merge_line(line: list[int]) -> tuple[list[int], int]:
method _apply_move (line 85) | def _apply_move(self, direction: str) -> tuple[bool, int]:
method _can_move (line 117) | def _can_move(self) -> bool:
method _render (line 128) | def _render(self) -> str:
function reward_score (line 132) | def reward_score(environments, **kwargs):
function main (line 136) | def main() -> None:
FILE: examples/scripts/grpo_agent.py
function query_reward (line 56) | def query_reward(completions, answer, **kwargs):
function correctness_reward (line 123) | def correctness_reward(completions, answer, **kwargs):
function structure_reward (line 151) | def structure_reward(completions, **kwargs):
class TimeoutError (line 193) | class TimeoutError(Exception):
function timeout (line 200) | def timeout(seconds):
function query_biogrid (line 214) | def query_biogrid(sql_command: str) -> list[tuple]:
function format_example (line 240) | def format_example(example):
FILE: examples/scripts/grpo_vlm.py
function make_conversation (line 119) | def make_conversation(example):
function filter_big_images (line 129) | def filter_big_images(example):
function convert_to_rgb (line 135) | def convert_to_rgb(example):
FILE: examples/scripts/gspo.py
function make_conversation (line 106) | def make_conversation(example):
FILE: examples/scripts/gspo_vlm.py
function make_conversation (line 108) | def make_conversation(example):
function filter_big_images (line 118) | def filter_big_images(example):
function convert_to_rgb (line 124) | def convert_to_rgb(example):
FILE: examples/scripts/mpo_vlm.py
function ensure_rgb (line 104) | def ensure_rgb(example):
FILE: examples/scripts/nemo_gym/train_multi_environment.py
class NeMoGymGRPOConfig (line 40) | class NeMoGymGRPOConfig(GRPOConfig):
function get_agent_servers (line 45) | def get_agent_servers(
function reward_fn (line 76) | def reward_fn(completions: list[str], **kwargs) -> list[float]:
function call_nemo_gym_agents (line 82) | async def call_nemo_gym_agents(
function nemo_gym_rollout_func (line 139) | def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> d...
function load_dataset_from_jsonl (line 295) | def load_dataset_from_jsonl(path: str) -> Dataset:
function main (line 312) | def main():
FILE: examples/scripts/online_dpo_vlm.py
function make_conversation (line 162) | def make_conversation(example):
function filter_big_images (line 173) | def filter_big_images(example):
function convert_to_rgb (line 179) | def convert_to_rgb(example):
FILE: examples/scripts/openenv/browsergym.py
function parse_args (line 97) | def parse_args() -> argparse.Namespace:
function sanitize_name (line 278) | def sanitize_name(name: str) -> str:
function make_user_prompt (line 309) | def make_user_prompt(goal: str, step_num: int, axtree: str, error: str =...
function parse_action (line 330) | def parse_action(response_text: str) -> str:
function rollout_once (line 342) | def rollout_once(
function reward_completion (line 456) | def reward_completion(completions: list[str], **kwargs) -> list[float]:
function main (line 469) | def main() -> None:
FILE: examples/scripts/openenv/browsergym_llm.py
function parse_args (line 79) | def parse_args() -> argparse.Namespace:
function sanitize_name (line 239) | def sanitize_name(name: str) -> str:
function make_user_prompt (line 273) | def make_user_prompt(goal: str, step_num: int, axtree: str, error: str =...
function parse_action (line 294) | def parse_action(response_text: str) -> str:
function rollout_once (line 306) | def rollout_once(
function reward_completion (line 393) | def reward_completion(completions: list[str], **kwargs) -> list[float]:
function main (line 406) | def main() -> None:
FILE: examples/scripts/openenv/carla.py
function parse_args (line 58) | def parse_args():
class CarlaGRPOEnv (line 113) | class CarlaGRPOEnv:
method __init__ (line 114) | def __init__(self):
method _describe (line 119) | def _describe(obs) -> str:
method _advance (line 132) | def _advance(self, ticks: int = SIM_TICKS):
method reset (line 141) | def reset(self, **kwargs) -> str | None:
method observe (line 146) | def observe(self) -> str:
method emergency_stop (line 157) | def emergency_stop(self) -> str:
method lane_change (line 169) | def lane_change(self, direction: str) -> str:
function reward_func (line 185) | def reward_func(completions, environments, **kwargs):
FILE: examples/scripts/openenv/catch.py
function parse_args (line 91) | def parse_args():
function start_env_server (line 135) | def start_env_server(env_host: str, env_port: int):
function reward_from_env (line 200) | def reward_from_env(completions, **kwargs):
function main (line 205) | def main():
FILE: examples/scripts/openenv/echo.py
function reward_func (line 42) | def reward_func(completions, environments, **kwargs):
class MyEchoEnv (line 46) | class MyEchoEnv:
method __init__ (line 47) | def __init__(self):
method reset (line 50) | def reset(self, **kwargs) -> None | str:
method step (line 54) | def step(self, message: str) -> str:
method get_reward (line 68) | def get_reward(self) -> float:
FILE: examples/scripts/openenv/sudoku.py
function parse_args (line 113) | def parse_args() -> argparse.Namespace:
function resolve_system_prompt (line 189) | def resolve_system_prompt(path: str) -> str:
function sanitize_name (line 196) | def sanitize_name(name: str) -> str:
function extract_sudoku_move (line 200) | def extract_sudoku_move(text: str) -> str:
function is_valid_board_state (line 217) | def is_valid_board_state(board_str: str) -> bool:
function parse_board (line 222) | def parse_board(board_str: str) -> list[list[int]]:
function count_filled_cells (line 244) | def count_filled_cells(board_str: str) -> int:
function get_valid_numbers (line 252) | def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[...
function extract_empty_cells_with_candidates (line 279) | def extract_empty_cells_with_candidates(
function extract_empty_cells (line 304) | def extract_empty_cells(board_str: str) -> list[tuple[int, int]]:
function extract_board_only (line 325) | def extract_board_only(text: str) -> str:
function make_compact_prompt (line 353) | def make_compact_prompt(
function check_move_targets_empty_cell (line 417) | def check_move_targets_empty_cell(move: str, board_str: str) -> bool:
function extract_feedback (line 431) | def extract_feedback(observation) -> dict:
function rollout_once (line 457) | def rollout_once(
function reward_empty_cell (line 671) | def reward_empty_cell(completions: list[str], **kwargs) -> list[float]:
function reward_valid_moves (line 679) | def reward_valid_moves(completions: list[str], **kwargs) -> list[float]:
function reward_correct (line 687) | def reward_correct(completions: list[str], **kwargs) -> list[float]:
function reward_repetition (line 695) | def reward_repetition(completions: list[str], **kwargs) -> list[float]:
function reward_progress (line 703) | def reward_progress(completions: list[str], **kwargs) -> list[float]:
function main (line 716) | def main() -> None:
FILE: examples/scripts/openenv/wordle.py
class WordleEnv (line 50) | class WordleEnv:
method __init__ (line 51) | def __init__(self):
method reset (line 54) | def reset(self, **kwargs) -> None | str:
method guess (line 63) | def guess(self, guess: str) -> str:
function reward (line 91) | def reward(environments, **kwargs) -> list[float]:
function main (line 95) | def main() -> None:
FILE: examples/scripts/ppo/ppo.py
function prepare_dataset (line 134) | def prepare_dataset(dataset, tokenizer):
FILE: examples/scripts/ppo/ppo_tldr.py
function prepare_dataset (line 137) | def prepare_dataset(dataset, tokenizer):
FILE: examples/scripts/rloo.py
function main (line 51) | def main():
FILE: examples/scripts/rloo_vlm.py
function make_conversation (line 119) | def make_conversation(example):
function filter_big_images (line 129) | def filter_big_images(example):
function convert_to_rgb (line 135) | def convert_to_rgb(example):
FILE: examples/scripts/sft_gemma3.py
function main (line 42) | def main():
FILE: examples/scripts/sft_gpt_oss.py
function main (line 63) | def main(script_args, training_args, model_args):
FILE: examples/scripts/sft_nemotron_3.py
function main (line 70) | def main(script_args, training_args, model_args):
FILE: examples/scripts/sft_tiny_aya_tool_calling.py
function create_conversation (line 80) | def create_conversation(sample):
function main (line 101) | def main():
FILE: examples/scripts/sft_video_llm.py
function download_video (line 73) | def download_video(url: str, cache_dir: str) -> str:
function prepare_dataset (line 94) | def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str...
function collate_fn (line 124) | def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
class CustomScriptArguments (line 165) | class CustomScriptArguments(ScriptArguments):
FILE: examples/scripts/sft_vlm_gemma3.py
function process_vision_info (line 83) | def process_vision_info(messages: list[dict]) -> list[Image.Image]:
function format_data (line 102) | def format_data(samples: dict[str, any]) -> dict[str, list]:
function prepare_dataset (line 127) | def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetD...
function main (line 143) | def main():
FILE: scripts/add_copyrights.py
function get_tracked_python_files (line 37) | def get_tracked_python_files():
function check_and_add_copyright (line 52) | def check_and_add_copyright(file_path):
function main (line 73) | def main():
FILE: scripts/generate_harmony_dataset.py
class ScriptArguments (line 22) | class ScriptArguments:
function main (line 49) | def main(test_size, push_to_hub, repo_id):
FILE: scripts/generate_tiny_models.py
function push_to_hub (line 106) | def push_to_hub(model, tokenizer, generation_config, prefix=None, suffix...
function init_weights_tiny_model (line 127) | def init_weights_tiny_model(model):
FILE: scripts/generate_toolcall_dataset.py
class ScriptArguments (line 24) | class ScriptArguments:
function main (line 51) | def main(test_size, push_to_hub, repo_id):
FILE: scripts/generate_zen_dataset.py
class ScriptArguments (line 22) | class ScriptArguments:
function main (line 49) | def main(test_size, push_to_hub, repo_id):
FILE: scripts/generate_zen_image_dataset.py
class ScriptArguments (line 26) | class ScriptArguments:
function main (line 53) | def main(test_size, push_to_hub, repo_id):
FILE: scripts/generate_zen_multi_image_dataset.py
class ScriptArguments (line 26) | class ScriptArguments:
function main (line 53) | def main(test_size, push_to_hub, repo_id):
FILE: scripts/log_reports.py
function process_log_file (line 34) | def process_log_file(log):
function main (line 65) | def main(slack_channel_name):
FILE: tests/conftest.py
function apply_model_revisions (line 44) | def apply_model_revisions(monkeypatch):
function cleanup_gpu (line 77) | def cleanup_gpu():
FILE: tests/distributed/test_distributed.py
function run_command (line 30) | def run_command(command: list[str], env: dict[str, str]) -> None:
function get_config_path (line 36) | def get_config_path(lazy_shared_datadir):
class TestDistributed (line 44) | class TestDistributed(
method test_sft (line 68) | def test_sft(self, config, get_config_path):
method test_dpo (line 103) | def test_dpo(self, config, get_config_path):
method test_sft_dataset_streaming (line 138) | def test_sft_dataset_streaming(self, config, get_config_path):
method test_sft_peft (line 175) | def test_sft_peft(self, config, get_config_path):
method test_reward (line 211) | def test_reward(self, config, get_config_path):
method test_rloo (line 240) | def test_rloo(self, config, get_config_path):
method test_grpo (line 270) | def test_grpo(self, config, get_config_path):
FILE: tests/experimental/test_async_grpo_trainer.py
function dummy_reward_func (line 29) | def dummy_reward_func(completions, **kwargs):
class _StubRolloutWorker (line 33) | class _StubRolloutWorker:
method __init__ (line 36) | def __init__(self, tokenizer, dataset, num_generations: int = 8, sampl...
method _make_sample_iter (line 42) | def _make_sample_iter(self, tokenizer, dataset, num_generations):
method _fill_queue (line 70) | def _fill_queue(self):
method start (line 74) | def start(self):
method update_model_version (line 77) | def update_model_version(self, version):
method stop (line 81) | def stop(self):
method pause (line 84) | def pause(self):
method resume (line 87) | def resume(self):
method send_weights (line 90) | def send_weights(self, iterator):
class TestAsyncGRPOTrainer (line 94) | class TestAsyncGRPOTrainer(TrlTestCase):
method test_init_minimal (line 95) | def test_init_minimal(self):
method test_training (line 106) | def test_training(self):
FILE: tests/experimental/test_bco_trainer.py
class TestBCOTrainer (line 35) | class TestBCOTrainer(TrlTestCase):
method test_train (line 48) | def test_train(self, config_name):
method test_train_with_precompute (line 84) | def test_train_with_precompute(self):
method test_train_eval (line 121) | def test_train_eval(self):
method test_init_with_ref_model_is_model (line 149) | def test_init_with_ref_model_is_model(self):
method test_tokenize_and_process_tokens (line 172) | def test_tokenize_and_process_tokens(self):
method test_train_without_providing_ref_model (line 226) | def test_train_without_providing_ref_model(self):
method test_train_udm (line 260) | def test_train_udm(self):
method test_train_without_providing_ref_model_with_lora (line 310) | def test_train_without_providing_ref_model_with_lora(self):
method test_generate_during_eval_no_wandb (line 348) | def test_generate_during_eval_no_wandb(self):
method test_lora_train_and_save (line 379) | def test_lora_train_and_save(self):
method test_compute_metrics (line 411) | def test_compute_metrics(self):
FILE: tests/experimental/test_cpo_trainer.py
class TestCPOTrainer (line 25) | class TestCPOTrainer(TrlTestCase):
method setup_method (line 26) | def setup_method(self):
method test_cpo_trainer (line 48) | def test_cpo_trainer(self, name, loss_type, config_name):
method test_cpo_trainer_with_lora (line 103) | def test_cpo_trainer_with_lora(self, config_name):
method test_compute_metrics (line 151) | def test_compute_metrics(self):
method test_alphapo_trainer (line 181) | def test_alphapo_trainer(self):
FILE: tests/experimental/test_dppo_trainer.py
class TestDPPODivergenceMask (line 24) | class TestDPPODivergenceMask:
method make_trainer (line 28) | def make_trainer(divergence_type="binary_tv", epsilon=0.2, epsilon_hig...
method compute_divergence_mask (line 41) | def compute_divergence_mask(
method test_binary_tv_no_masking_within_threshold (line 60) | def test_binary_tv_no_masking_within_threshold(self):
method test_binary_tv_masks_positive_advantage_high_divergence (line 72) | def test_binary_tv_masks_positive_advantage_high_divergence(self):
method test_binary_tv_masks_negative_advantage_low_divergence (line 83) | def test_binary_tv_masks_negative_advantage_low_divergence(self):
method test_binary_tv_respects_completion_mask (line 94) | def test_binary_tv_respects_completion_mask(self):
method test_topk_tv_requires_topk_inputs (line 105) | def test_topk_tv_requires_topk_inputs(self):
class TestDPPOTrainer (line 132) | class TestDPPOTrainer(TrlTestCase):
method test_training_binary (line 134) | def test_training_binary(self, divergence_type):
method test_training_conversational (line 164) | def test_training_conversational(self, config_name):
FILE: tests/experimental/test_gkd_trainer.py
class TestGKDTrainerGenerateOnPolicy (line 28) | class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
method setup_class (line 30) | def setup_class(cls):
method test_generate_on_policy_outputs_deterministic (line 43) | def test_generate_on_policy_outputs_deterministic(self):
method test_generate_on_policy_outputs (line 91) | def test_generate_on_policy_outputs(self):
class TestGeneralizedJSDLoss (line 126) | class TestGeneralizedJSDLoss(TrlTestCase):
method setup_method (line 127) | def setup_method(self):
method test_uniform_distribution (line 134) | def test_uniform_distribution(self):
method test_generalized_jsd_loss_edge_cases (line 139) | def test_generalized_jsd_loss_edge_cases(self):
method test_output_shape (line 158) | def test_output_shape(self):
method test_beta_values (line 163) | def test_beta_values(self):
method test_temperature_scaling (line 168) | def test_temperature_scaling(self):
method test_reduction_methods (line 173) | def test_reduction_methods(self):
method test_symmetry (line 186) | def test_symmetry(self):
method test_zero_loss_for_identical_inputs (line 195) | def test_zero_loss_for_identical_inputs(self):
class TestGKDTrainer (line 201) | class TestGKDTrainer(TrlTestCase):
method setup_method (line 202) | def setup_method(self):
method test_gkd_trainer (line 209) | def test_gkd_trainer(self):
method test_gkd_trainer_with_liger (line 240) | def test_gkd_trainer_with_liger(self):
method test_generation_config_init (line 265) | def test_generation_config_init(self):
FILE: tests/experimental/test_gold_trainer.py
function openr1_examples (line 27) | def openr1_examples():
function countdown_examples (line 40) | def countdown_examples():
function _teacher_inputs_from_collator (line 52) | def _teacher_inputs_from_collator(student_tok, teacher_tok, batch):
function _assert_alignment_covers_completion (line 76) | def _assert_alignment_covers_completion(loss_fn, batch, teacher_input_id...
function test_chatml_collator_preserves_completion_llama (line 97) | def test_chatml_collator_preserves_completion_llama(llama_tokenizer, qwe...
function test_chatml_collator_preserves_completion_llama_countdown (line 142) | def test_chatml_collator_preserves_completion_llama_countdown(llama_toke...
function test_chatml_collator_preserves_completion_smollm (line 187) | def test_chatml_collator_preserves_completion_smollm(smollm_tokenizer, q...
function build_config (line 231) | def build_config(**overrides):
function llama_tokenizer (line 250) | def llama_tokenizer():
function qwen_tokenizer (line 258) | def qwen_tokenizer():
function smollm_tokenizer (line 266) | def smollm_tokenizer():
function encode_prompt_completion (line 273) | def encode_prompt_completion(tokenizer, prompt, completion):
function pad_tokens (line 284) | def pad_tokens(ids, pad_id, target_length):
function pad_labels (line 288) | def pad_labels(labels, target_length):
function test_process_completions_to_buffer_left_pads_prompt_retokenization (line 292) | def test_process_completions_to_buffer_left_pads_prompt_retokenization():
function test_alignment_groups_cover_all_tokens (line 374) | def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokeniz...
function test_merge_probabilities_multiplies_split_tokens (line 389) | def test_merge_probabilities_multiplies_split_tokens():
function test_initialize_vocabulary_mapping_contains_common_tokens (line 411) | def test_initialize_vocabulary_mapping_contains_common_tokens(llama_toke...
function test_get_start_and_size_answers_skips_prompt_tokens (line 431) | def test_get_start_and_size_answers_skips_prompt_tokens():
function test_generate_on_policy_outputs_masks_prompt (line 450) | def test_generate_on_policy_outputs_masks_prompt(llama_tokenizer):
function test_generate_on_policy_outputs_masks_prompt_smollm (line 502) | def test_generate_on_policy_outputs_masks_prompt_smollm(smollm_tokenizer...
function test_generalized_jsd_loss_accepts_probability_inputs (line 550) | def test_generalized_jsd_loss_accepts_probability_inputs():
function test_uldloss_handles_llama_student_qwen_teacher_sequence (line 570) | def test_uldloss_handles_llama_student_qwen_teacher_sequence(llama_token...
function test_uldloss_handles_smollm_student_qwen_teacher_sequence (line 619) | def test_uldloss_handles_smollm_student_qwen_teacher_sequence(smollm_tok...
function test_uldloss_hybrid_config_beta_zero (line 668) | def test_uldloss_hybrid_config_beta_zero(llama_tokenizer, qwen_tokenizer):
FILE: tests/experimental/test_grpo_with_replay_buffer_trainer.py
class TestReplayBuffer (line 29) | class TestReplayBuffer:
method setup_method (line 30) | def setup_method(self):
method test_add (line 33) | def test_add(self):
method test_add_more_than_maxlen (line 53) | def test_add_more_than_maxlen(self):
method test_sample (line 75) | def test_sample(self):
class TestUpdateWithReplayBuffer (line 97) | class TestUpdateWithReplayBuffer:
method setup_method (line 98) | def setup_method(self):
method _prepopulate_buffer (line 112) | def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False):
method _make_inputs (line 136) | def _make_inputs(self, group_advantages, with_pixels=False, with_logpr...
method test_update_with_replay_buffer_no_variance (line 149) | def test_update_with_replay_buffer_no_variance(self):
method test_update_with_replay_buffer_with_variance (line 164) | def test_update_with_replay_buffer_with_variance(self):
method test_update_with_mixed_variance (line 174) | def test_update_with_mixed_variance(self):
method test_update_with_inputs_different_seq_len (line 193) | def test_update_with_inputs_different_seq_len(self):
class TestGRPOWithReplayBufferTrainer (line 255) | class TestGRPOWithReplayBufferTrainer(TrlTestCase):
method test_training_with_replay_buffer (line 256) | def test_training_with_replay_buffer(self, scale_rewards):
FILE: tests/experimental/test_gspo_token_trainer.py
class TestGSPOTokenTrainer (line 30) | class TestGSPOTokenTrainer(TrlTestCase):
method test_training (line 31) | def test_training(self):
FILE: tests/experimental/test_judges.py
class RandomBinaryJudge (line 28) | class RandomBinaryJudge(BaseBinaryJudge):
method judge (line 33) | def judge(self, prompts, completions, gold_completions=None, shuffle_o...
class TestJudges (line 37) | class TestJudges(TrlTestCase):
method _get_prompts_and_pairwise_completions (line 38) | def _get_prompts_and_pairwise_completions(self):
method _get_prompts_and_single_completions (line 43) | def _get_prompts_and_single_completions(self):
method test_all_true_judge (line 48) | def test_all_true_judge(self):
method test_hugging_face_judge (line 56) | def test_hugging_face_judge(self):
method load_pair_rm_judge (line 64) | def load_pair_rm_judge(self):
method test_pair_rm_judge (line 83) | def test_pair_rm_judge(self):
method test_pair_rm_judge_return_scores (line 100) | def test_pair_rm_judge_return_scores(self):
FILE: tests/experimental/test_kto_trainer.py
class TestKTOTrainer (line 27) | class TestKTOTrainer(TrlTestCase):
method setup_method (line 28) | def setup_method(self):
method test_kto_trainer (line 44) | def test_kto_trainer(self, config_name, loss_type, pre_compute, eval_d...
method test_kto_trainer_with_ref_model_is_model (line 82) | def test_kto_trainer_with_ref_model_is_model(self):
method test_tokenize_and_process_tokens (line 101) | def test_tokenize_and_process_tokens(self):
method test_kto_trainer_without_providing_ref_model (line 175) | def test_kto_trainer_without_providing_ref_model(self):
method test_kto_trainer_without_providing_ref_model_with_lora (line 212) | def test_kto_trainer_without_providing_ref_model_with_lora(self):
method test_kto_trainer_generate_during_eval_no_wandb (line 261) | def test_kto_trainer_generate_during_eval_no_wandb(self):
method test_kto_trainer_with_liger (line 292) | def test_kto_trainer_with_liger(self):
method test_compute_metrics (line 322) | def test_compute_metrics(self):
FILE: tests/experimental/test_merge_model_callback.py
class TestMergeModelCallback (line 29) | class TestMergeModelCallback(TrlTestCase):
method setup_method (line 30) | def setup_method(self):
method test_callback (line 37) | def test_callback(self):
method test_every_checkpoint (line 59) | def test_every_checkpoint(self):
FILE: tests/experimental/test_minillm_trainer.py
class TestMiniLLMTrainer (line 25) | class TestMiniLLMTrainer(TrlTestCase):
method test_train (line 26) | def test_train(self):
FILE: tests/experimental/test_modeling_value_head.py
class TestReferenceModel (line 24) | class TestReferenceModel(TrlTestCase):
method setup_method (line 25) | def setup_method(self):
method test_independent_reference (line 31) | def test_independent_reference(self):
method test_shared_layers (line 65) | def test_shared_layers(self):
FILE: tests/experimental/test_nash_md_trainer.py
class TestGeometricMixtureWrapper (line 33) | class TestGeometricMixtureWrapper(TrlTestCase):
method setup_method (line 34) | def setup_method(self):
method test_forward (line 45) | def test_forward(self):
method test_mixture_coefficient (line 55) | def test_mixture_coefficient(self):
method test_prepare_inputs_for_generation (line 70) | def test_prepare_inputs_for_generation(self):
class TestNashMDTrainer (line 81) | class TestNashMDTrainer(TrlTestCase):
method setup_method (line 82) | def setup_method(self):
method test_nash_md_trainer_training (line 91) | def test_nash_md_trainer_training(self, config_name):
method test_training_with_peft (line 120) | def test_training_with_peft(self):
method test_training_with_peft_and_ref_model (line 148) | def test_training_with_peft_and_ref_model(self):
method test_training_pre_pefted_model_implicit_ref_with_reward_model (line 177) | def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
method test_nash_md_trainer_judge_training (line 209) | def test_nash_md_trainer_judge_training(self, config_name):
FILE: tests/experimental/test_online_dpo_trainer.py
class TestOnlineDPOTrainer (line 44) | class TestOnlineDPOTrainer(TrlTestCase):
method setup_method (line 45) | def setup_method(self):
method test_training (line 58) | def test_training(self, config_name):
method test_training_model_str (line 83) | def test_training_model_str(self):
method test_training_with_ref_model (line 108) | def test_training_with_ref_model(self):
method test_ref_model_is_model (line 134) | def test_ref_model_is_model(self):
method test_training_with_peft (line 156) | def test_training_with_peft(self):
method test_training_with_peft_and_ref_model (line 185) | def test_training_with_peft_and_ref_model(self):
method test_training_with_judge (line 216) | def test_training_with_judge(self, config_name):
method test_training_with_vllm_server (line 244) | def test_training_with_vllm_server(self, config_name):
method test_training_with_vllm_colocate (line 285) | def test_training_with_vllm_colocate(self):
method test_vllm_config_validation (line 344) | def test_vllm_config_validation(self):
method test_generation_config_setup (line 371) | def test_generation_config_setup(self):
method test_training_with_transformers_paged (line 410) | def test_training_with_transformers_paged(self, config_name):
method test_training_with_reward_funcs (line 439) | def test_training_with_reward_funcs(self, config_name):
class TestOnlineDPOVisionTrainer (line 472) | class TestOnlineDPOVisionTrainer(TrlTestCase):
method test_online_dpo_vlm_trainer (line 480) | def test_online_dpo_vlm_trainer(self, model_id):
FILE: tests/experimental/test_orpo_trainer.py
class TestORPOTrainer (line 25) | class TestORPOTrainer(TrlTestCase):
method setup_method (line 26) | def setup_method(self):
method test_orpo_trainer (line 45) | def test_orpo_trainer(self, name, config_name):
method test_orpo_trainer_with_lora (line 98) | def test_orpo_trainer_with_lora(self, config_name):
method test_compute_metrics (line 145) | def test_compute_metrics(self):
FILE: tests/experimental/test_ppo_trainer.py
class TestBatchGeneration (line 74) | class TestBatchGeneration(TrlTestCase):
method setup_method (line 75) | def setup_method(self):
method test_mini_batch_generation (line 95) | def test_mini_batch_generation(self):
method test_single_batch_generation (line 114) | def test_single_batch_generation(self):
class BaseTester (line 134) | class BaseTester:
class VHeadModelTester (line 135) | class VHeadModelTester(TrlTestCase):
method setup_method (line 140) | def setup_method(self):
method test_value_head (line 143) | def test_value_head(self):
method test_value_head_shape (line 151) | def test_value_head_shape(self):
method test_value_head_init_random (line 159) | def test_value_head_init_random(self):
method test_value_head_not_str (line 168) | def test_value_head_not_str(self):
method test_from_save_trl (line 178) | def test_from_save_trl(self):
method test_from_save_trl_sharded (line 194) | def test_from_save_trl_sharded(self):
method test_from_save_transformers_sharded (line 209) | def test_from_save_transformers_sharded(self):
method test_from_save_transformers (line 229) | def test_from_save_transformers(self):
class TestCausalLMValueHeadModel (line 264) | class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
method teardown_method (line 273) | def teardown_method(self):
method test_inference (line 277) | def test_inference(self):
method test_dropout_config (line 293) | def test_dropout_config(self):
method test_dropout_kwargs (line 305) | def test_dropout_kwargs(self):
method test_generate (line 323) | def test_generate(self, model_name):
method test_transformers_bf16_kwargs (line 334) | def test_transformers_bf16_kwargs(self):
method test_push_to_hub (line 359) | def test_push_to_hub(self):
class TestSeq2SeqValueHeadModel (line 378) | class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
method teardown_method (line 387) | def teardown_method(self):
method test_inference (line 391) | def test_inference(self):
method test_dropout_config (line 408) | def test_dropout_config(self):
method test_dropout_kwargs (line 420) | def test_dropout_kwargs(self):
method test_generate (line 438) | def test_generate(self, model_name):
method test_push_to_hub (line 451) | def test_push_to_hub(self):
method test_transformers_bf16_kwargs (line 469) | def test_transformers_bf16_kwargs(self):
class TestPeftModel (line 493) | class TestPeftModel(TrlTestCase):
method setup_method (line 494) | def setup_method(self):
method test_create_peft_model (line 504) | def test_create_peft_model(self):
method test_peft_requires_grad (line 513) | def test_peft_requires_grad(self):
method test_check_peft_model_nb_trainable_params (line 525) | def test_check_peft_model_nb_trainable_params(self):
method test_create_peft_model_from_config (line 543) | def test_create_peft_model_from_config(self):
method test_create_bnb_peft_model_from_config (line 562) | def test_create_bnb_peft_model_from_config(self):
method test_save_pretrained_peft (line 588) | def test_save_pretrained_peft(self):
method test_load_pretrained_peft (line 622) | def test_load_pretrained_peft(self):
method test_continue_training_peft_model (line 647) | def test_continue_training_peft_model(self):
class TestCore (line 662) | class TestCore(TrlTestCase):
method setup_method (line 667) | def setup_method(self):
method test_masked_mean (line 672) | def test_masked_mean(self):
method test_masked_var (line 675) | def test_masked_var(self):
method test_masked_whiten (line 678) | def test_masked_whiten(self):
class TestPPOTrainer (line 689) | class TestPPOTrainer(TrlTestCase):
method setup_method (line 690) | def setup_method(self):
method test_basic_training (line 715) | def test_basic_training(self):
method test_peft_training (line 767) | def test_peft_training(self):
FILE: tests/experimental/test_prm_trainer.py
class TestComputeAccuracy (line 34) | class TestComputeAccuracy(TrlTestCase):
method test_token_classification_task (line 35) | def test_token_classification_task(self):
method test_token_classification_task_with_ignored_tokens_0 (line 49) | def test_token_classification_task_with_ignored_tokens_0(self):
method test_token_classification_task_with_ignored_tokens_1 (line 63) | def test_token_classification_task_with_ignored_tokens_1(self):
method test_rewards_comparison_task (line 77) | def test_rewards_comparison_task(self, caplog):
class TestTokenizeRow (line 101) | class TestTokenizeRow(TrlTestCase):
method setup_method (line 102) | def setup_method(self):
method test_tokenize_row_no_truncation (line 125) | def test_tokenize_row_no_truncation(self):
method test_tokenize_row_train_on_last_step_only (line 149) | def test_tokenize_row_train_on_last_step_only(self):
method test_tokenize_row_completion_truncation (line 172) | def test_tokenize_row_completion_truncation(self):
method test_tokenize_row_prompt_completion_truncation (line 196) | def test_tokenize_row_prompt_completion_truncation(self):
method test_tokenize_row_multi_token_separator (line 220) | def test_tokenize_row_multi_token_separator(self):
class TestPRMTrainer (line 245) | class TestPRMTrainer(TrlTestCase):
method setup_method (line 246) | def setup_method(self):
method test_train_full (line 252) | def test_train_full(self, train_on_last_step_only):
method test_train_full_pretokenized (line 272) | def test_train_full_pretokenized(self):
method test_train_lora (line 326) | def test_train_lora(self):
method test_tags (line 370) | def test_tags(self):
FILE: tests/experimental/test_utils.py
class TestDataCollatorForChatML (line 24) | class TestDataCollatorForChatML(TrlTestCase):
method setup_method (line 25) | def setup_method(self):
method test_data_collator_for_chatml (line 50) | def test_data_collator_for_chatml(self):
FILE: tests/experimental/test_winrate_callback.py
class HalfPairwiseJudge (line 30) | class HalfPairwiseJudge(BasePairwiseJudge):
method judge (line 33) | def judge(self, prompts, completions, shuffle_order=True, return_score...
class TrainerWithRefModel (line 41) | class TrainerWithRefModel(Trainer):
method __init__ (line 44) | def __init__(self, model, ref_model, args, train_dataset, eval_dataset...
class TestWinRateCallback (line 56) | class TestWinRateCallback(TrlTestCase):
method setup_method (line 57) | def setup_method(self):
method test_basic (line 86) | def test_basic(self):
method test_without_ref_model (line 112) | def test_without_ref_model(self):
method test_soft_judge (line 138) | def test_soft_judge(self):
method test_lora (line 182) | def test_lora(self):
FILE: tests/experimental/test_xpo_trainer.py
class TestXPOTrainer (line 31) | class TestXPOTrainer(TrlTestCase):
method setup_method (line 32) | def setup_method(self):
method test_xpo_trainer_training (line 41) | def test_xpo_trainer_training(self, config_name):
method test_training_with_peft (line 70) | def test_training_with_peft(self):
method test_training_with_peft_and_ref_model (line 98) | def test_training_with_peft_and_ref_model(self):
method test_training_pre_pefted_model_implicit_ref (line 127) | def test_training_pre_pefted_model_implicit_ref(self):
method test_xpo_trainer_judge_training (line 157) | def test_xpo_trainer_judge_training(self, config_name):
FILE: tests/experimental/testing_utils.py
class RandomPairwiseJudge (line 19) | class RandomPairwiseJudge(BasePairwiseJudge):
method judge (line 24) | def judge(self, prompts, completions, shuffle_order=True, return_score...
FILE: tests/test_activation_offloading.py
class TestActivationOffloading (line 30) | class TestActivationOffloading(TrlTestCase):
method test_offloading_with_peft_models (line 33) | def test_offloading_with_peft_models(self) -> None:
method test_noop_manager_with_offloading (line 80) | def test_noop_manager_with_offloading(self):
method test_min_offload_size (line 110) | def test_min_offload_size(self):
method test_real_hf_model (line 127) | def test_real_hf_model(self):
method test_tensor_deduplication (line 159) | def test_tensor_deduplication(self):
method test_parameter_filtering (line 200) | def test_parameter_filtering(self):
FILE: tests/test_callbacks.py
class TestLogCompletionsCallback (line 27) | class TestLogCompletionsCallback(TrlTestCase):
method setup_method (line 28) | def setup_method(self):
method test_basic_wandb (line 45) | def test_basic_wandb(self):
method test_basic_comet (line 82) | def test_basic_comet(self):
class TestBEMACallback (line 120) | class TestBEMACallback(TrlTestCase):
method setup_method (line 121) | def setup_method(self):
method test_model_saved (line 136) | def test_model_saved(self):
method test_update_frequency_0 (line 154) | def test_update_frequency_0(self):
method test_update_frequency_1 (line 174) | def test_update_frequency_1(self):
method test_update_frequency_2 (line 194) | def test_update_frequency_2(self):
method test_no_bema (line 214) | def test_no_bema(self):
method test_no_ema (line 227) | def test_no_ema(self):
FILE: tests/test_chat_template_utils.py
class TestCloneChatTemplate (line 33) | class TestCloneChatTemplate(TrlTestCase):
method test_clone (line 34) | def test_clone(self):
method test_clone_with_resize (line 45) | def test_clone_with_resize(self):
method test_clone_with_resize_and_extra_tokens_already_in_vocab (line 60) | def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
method test_apply_new_chat_template (line 80) | def test_apply_new_chat_template(self):
method test_clone_with_sequence_classification_model (line 99) | def test_clone_with_sequence_classification_model(self):
class TestAddResponseSchema (line 126) | class TestAddResponseSchema:
method test_add_response_schema (line 127) | def test_add_response_schema(self, tokenizer_name):
class TestIsChatTemplatePrefixPreserving (line 146) | class TestIsChatTemplatePrefixPreserving:
method test_prefix_preserving_template (line 147) | def test_prefix_preserving_template(self):
method test_non_prefix_preserving_template (line 165) | def test_non_prefix_preserving_template(self):
class TestGetTrainingChatTemplate (line 231) | class TestGetTrainingChatTemplate:
method test_new_chat_template_is_prefix_preserving (line 232) | def test_new_chat_template_is_prefix_preserving(self, tokenizer_name):
method test_behavior_unchanged_single_user_no_generation_prompt (line 238) | def test_behavior_unchanged_single_user_no_generation_prompt(self, tok...
method test_behavior_unchanged_single_user_with_generation_prompt (line 246) | def test_behavior_unchanged_single_user_with_generation_prompt(self, t...
method test_behavior_unchanged_single_user_and_final_assistant_plain_content (line 259) | def test_behavior_unchanged_single_user_and_final_assistant_plain_cont...
method test_behavior_unchanged_final_assistant_with_reasoning_content (line 270) | def test_behavior_unchanged_final_assistant_with_reasoning_content(sel...
method test_behavior_unchanged_final_assistant_with_existing_think_tags (line 285) | def test_behavior_unchanged_final_assistant_with_existing_think_tags(s...
method test_behavior_unchanged_assistant_with_tool_calls (line 299) | def test_behavior_unchanged_assistant_with_tool_calls(self, tokenizer_...
method test_behavior_unchanged_with_tools_with_and_without_system_message (line 314) | def test_behavior_unchanged_with_tools_with_and_without_system_message...
method test_behavior_unchanged_with_tools_with_system_message (line 339) | def test_behavior_unchanged_with_tools_with_system_message(self, token...
method test_behavior_unchanged_generation_prompt_with_enable_thinking_false (line 364) | def test_behavior_unchanged_generation_prompt_with_enable_thinking_fal...
class TestParseResponse (line 394) | class TestParseResponse:
method test_parse_response (line 395) | def test_parse_response(self, tokenizer_name):
method test_parse_response_with_reasoning_content (line 408) | def test_parse_response_with_reasoning_content(self, tokenizer_name):
method test_parse_response_tool_call (line 425) | def test_parse_response_tool_call(self, tokenizer_name):
method test_parse_response_tool_call_with_content (line 439) | def test_parse_response_tool_call_with_content(self, tokenizer_name):
method test_parse_response_tool_call_without_arguments (line 453) | def test_parse_response_tool_call_without_arguments(self, tokenizer_na...
method test_parse_response_multiple_tool_calls (line 467) | def test_parse_response_multiple_tool_calls(self, tokenizer_name):
method test_parse_response_malformed_tool_call (line 484) | def test_parse_response_malformed_tool_call(self, tokenizer_name):
FILE: tests/test_cli.py
function test_help_no_type_error (line 26) | def test_help_no_type_error(command):
class TestCLI (line 37) | class TestCLI(TrlTestCase):
method test_dpo (line 38) | def test_dpo(self):
method test_dpo_multiple_loss_types (line 45) | def test_dpo_multiple_loss_types(self):
method test_env (line 53) | def test_env(self, mock_stdout):
method test_grpo (line 61) | def test_grpo(self):
method test_kto (line 68) | def test_kto(self):
method test_reward (line 75) | def test_reward(self):
method test_rloo (line 82) | def test_rloo(self):
method test_sft (line 89) | def test_sft(self):
method test_sft_config_file (line 96) | def test_sft_config_file(self):
method test_vllm_serve_config_file (line 122) | def test_vllm_serve_config_file(self):
FILE: tests/test_cli_utils.py
class MyDataclass (line 29) | class MyDataclass:
class InvalidDataclass (line 35) | class InvalidDataclass:
class TestTrlParser (line 39) | class TestTrlParser(TrlTestCase):
method test_init_without_config_field (line 40) | def test_init_without_config_field(self):
method test_init_with_config_field (line 45) | def test_init_with_config_field(self):
method test_parse_args_and_config_with_valid_config (line 53) | def test_parse_args_and_config_with_valid_config(self, mock_environ, m...
method test_parse_args_and_arg_override_config (line 80) | def test_parse_args_and_arg_override_config(self, mock_yaml_load):
method test_parse_args_and_config_with_invalid_env (line 98) | def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load):
method test_parse_args_and_config_without_config (line 109) | def test_parse_args_and_config_without_config(self):
method test_set_defaults_with_config (line 124) | def test_set_defaults_with_config(self):
method test_parse_args_and_config_with_remaining_strings (line 137) | def test_parse_args_and_config_with_remaining_strings(self):
method test_parse_args_and_config_with_remaining_strings_in_config_and_args (line 154) | def test_parse_args_and_config_with_remaining_strings_in_config_and_ar...
method test_subparsers_with_config_defaults (line 172) | def test_subparsers_with_config_defaults(self, mock_yaml_load):
method test_subparsers_with_config_defaults_and_arg_override (line 198) | def test_subparsers_with_config_defaults_and_arg_override(self, mock_y...
method test_subparsers_with_config_defaults_and_arg_override_wrong_name (line 221) | def test_subparsers_with_config_defaults_and_arg_override_wrong_name(s...
method test_subparsers_multiple_with_config_defaults (line 243) | def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
class TestGetDataset (line 270) | class TestGetDataset:
method test_single_dataset_with_config (line 271) | def test_single_dataset_with_config(self):
method test_single_dataset_preference_config (line 279) | def test_single_dataset_preference_config(self):
method test_single_dataset_streaming (line 287) | def test_single_dataset_streaming(self):
method test_dataset_mixture_basic (line 296) | def test_dataset_mixture_basic(self):
method test_dataset_mixture_with_weights (line 315) | def test_dataset_mixture_with_weights(self):
method test_dataset_mixture_with_test_split (line 336) | def test_dataset_mixture_with_test_split(self):
method test_empty_dataset_mixture_raises_error (line 348) | def test_empty_dataset_mixture_raises_error(self):
method test_mixture_multiple_different_configs (line 354) | def test_mixture_multiple_different_configs(self):
method test_trlparser_parses_yaml_config_correctly (line 367) | def test_trlparser_parses_yaml_config_correctly(self):
method test_trlparser_parses_yaml_and_loads_dataset (line 407) | def test_trlparser_parses_yaml_and_loads_dataset(self):
FILE: tests/test_data_utils.py
class TestPrepareMultimodalMessages (line 49) | class TestPrepareMultimodalMessages:
method test_basic_user_assistant_conversation (line 50) | def test_basic_user_assistant_conversation(self):
method test_first_user_message_gets_image (line 72) | def test_first_user_message_gets_image(self):
method test_multiple_images (line 100) | def test_multiple_images(self):
method test_system_message_transformation (line 127) | def test_system_message_transformation(self):
method test_already_prepared_messages_unchanged (line 150) | def test_already_prepared_messages_unchanged(self):
method test_mixed_prepared_and_unprepared_messages (line 178) | def test_mixed_prepared_and_unprepared_messages(self):
method test_message_with_tool_calling_turns (line 206) | def test_message_with_tool_calling_turns(self):
class TestPrepareMultimodalMessagesVLLM (line 250) | class TestPrepareMultimodalMessagesVLLM:
method test_single_image_conversion (line 251) | def test_single_image_conversion(self):
method test_mixed_content_conversion (line 274) | def test_mixed_content_conversion(self):
method test_no_images (line 291) | def test_no_images(self):
method test_multiple_messages (line 302) | def test_multiple_messages(self):
method test_deepcopy_integrity (line 323) | def test_deepcopy_integrity(self):
class TestIsConversational (line 341) | class TestIsConversational(TrlTestCase):
method test_conversational (line 469) | def test_conversational(self, example):
method test_non_conversational (line 473) | def test_non_conversational(self, example):
class TestIsConversationalFromValue (line 477) | class TestIsConversationalFromValue(TrlTestCase):
method test_positive_1 (line 478) | def test_positive_1(self):
method test_negative_1 (line 487) | def test_negative_1(self):
method test_negative_2 (line 496) | def test_negative_2(self):
class TestApplyChatTemplate (line 501) | class TestApplyChatTemplate(TrlTestCase):
method test_apply_chat_template (line 582) | def test_apply_chat_template(self, tokenizer_id, example):
method test_maybe_apply_chat_template (line 609) | def test_maybe_apply_chat_template(self, tokenizer_id, example):
method test_apply_chat_template_with_chat_template_kwargs (line 633) | def test_apply_chat_template_with_chat_template_kwargs(self):
method test_apply_chat_template_with_tools (line 656) | def test_apply_chat_template_with_tools(self):
class TestApplyChatTemplateHarmony (line 688) | class TestApplyChatTemplateHarmony(TrlTestCase):
method test_language_modeling (line 689) | def test_language_modeling(self):
method test_prompt_only (line 720) | def test_prompt_only(self):
method test_prompt_completion (line 750) | def test_prompt_completion(self):
method test_preference (line 785) | def test_preference(self):
method test_preference_with_implicit_prompt (line 825) | def test_preference_with_implicit_prompt(self):
method test_unpaired_preference (line 876) | def test_unpaired_preference(self):
class TestUnpairPreferenceDataset (line 914) | class TestUnpairPreferenceDataset(TrlTestCase):
method test_unpair_preference_dataset (line 931) | def test_unpair_preference_dataset(self):
method test_unpair_preference_dataset_dict (line 938) | def test_unpair_preference_dataset_dict(self):
method test_maybe_unpair_preference_dataset (line 946) | def test_maybe_unpair_preference_dataset(self):
method test_maybe_unpair_preference_dataset_dict (line 953) | def test_maybe_unpair_preference_dataset_dict(self):
method test_maybe_unpair_preference_dataset_already_paired (line 961) | def test_maybe_unpair_preference_dataset_already_paired(self):
method test_maybe_unpair_preference_dataset_dict_already_paired (line 968) | def test_maybe_unpair_preference_dataset_dict_already_paired(self):
class TestExtractPrompt (line 976) | class TestExtractPrompt(TrlTestCase):
method test_extract_prompt_conversational (line 1011) | def test_extract_prompt_conversational(self):
method test_maybe_extract_prompt_conversational (line 1018) | def test_maybe_extract_prompt_conversational(self):
method test_maybe_extract_prompt_conversational_already_explicit (line 1025) | def test_maybe_extract_prompt_conversational_already_explicit(self):
method test_extract_prompt_standard (line 1032) | def test_extract_prompt_standard(self):
method test_maybe_extract_prompt_standard (line 1039) | def test_maybe_extract_prompt_standard(self):
method test_maybe_extract_prompt_standard_already_explicit (line 1046) | def test_maybe_extract_prompt_standard_already_explicit(self):
class TestPackDatasetWrapped (line 1052) | class TestPackDatasetWrapped(TrlTestCase):
method test_with_dataset (line 1053) | def test_with_dataset(self):
method test_with_iterable_dataset (line 1070) | def test_with_iterable_dataset(self):
class TestPackDatasetBfd (line 1089) | class TestPackDatasetBfd(TrlTestCase):
method test_with_dataset (line 1090) | def test_with_dataset(self):
method test_with_iterable_dataset (line 1109) | def test_with_iterable_dataset(self):
method test_with_overlong_0 (line 1126) | def test_with_overlong_0(self):
method test_with_overlong_two_coluns (line 1139) | def test_with_overlong_two_coluns(self):
method test_with_non_power_of_2 (line 1154) | def test_with_non_power_of_2(self):
method test_default_no_split (line 1167) | def test_default_no_split(self):
method test_with_empty_sequences (line 1182) | def test_with_empty_sequences(self):
class TestTruncateExamples (line 1196) | class TestTruncateExamples(TrlTestCase):
method test_with_dataset (line 1197) | def test_with_dataset(self):
method test_with_iterable_dataset (line 1214) | def test_with_iterable_dataset(self):
method test_with_extra_column (line 1232) | def test_with_extra_column(self):
method test_with_keep_end (line 1248) | def test_with_keep_end(self):
method test_with_keep_end_and_zero_max_length (line 1261) | def test_with_keep_end_and_zero_max_length(self):
class TestMaybeConvertToChatML (line 1275) | class TestMaybeConvertToChatML(TrlTestCase):
method test_with_conversations_key (line 1276) | def test_with_conversations_key(self):
method test_without_conversations_key (line 1292) | def test_without_conversations_key(self):
method test_not_conversional (line 1304) | def test_not_conversional(self):
method test_already_chatml (line 1309) | def test_already_chatml(self):
FILE: tests/test_dpo_trainer.py
class TestDataCollatorForPreference (line 42) | class TestDataCollatorForPreference(TrlTestCase):
method test_padding_and_masks (line 43) | def test_padding_and_masks(self):
method test_optional_reference_logps (line 81) | def test_optional_reference_logps(self):
method test_with_pad_to_multiple_of (line 114) | def test_with_pad_to_multiple_of(self):
class TestDataCollatorForVisionPreference (line 135) | class TestDataCollatorForVisionPreference(TrlTestCase):
method test_mm_token_type_ids_shape (line 141) | def test_mm_token_type_ids_shape(self):
class TestDPOTrainer (line 167) | class TestDPOTrainer(TrlTestCase):
method test_train (line 176) | def test_train(self, model_id):
method test_train_gpt_oss (line 203) | def test_train_gpt_oss(self):
method test_train_model (line 231) | def test_train_model(self):
method test_train_loss_types (line 282) | def test_train_loss_types(self, loss_type):
method test_train_multi_loss_types (line 317) | def test_train_multi_loss_types(self):
method test_train_with_wpo (line 348) | def test_train_with_wpo(self):
method test_train_with_ld (line 379) | def test_train_with_ld(self):
method test_train_with_f_divergence (line 414) | def test_train_with_f_divergence(self, f_divergence_type):
method test_train_with_explicit_ref_model (line 445) | def test_train_with_explicit_ref_model(self):
method test_training_with_sync_ref_model (line 483) | def test_training_with_sync_ref_model(self):
method test_train_model_dtype (line 517) | def test_train_model_dtype(self):
method test_train_dense_with_peft_config_lora (line 553) | def test_train_dense_with_peft_config_lora(self):
method test_train_moe_with_peft_config (line 594) | def test_train_moe_with_peft_config(self):
method test_train_peft_model (line 635) | def test_train_peft_model(self):
method test_train_with_peft_config_and_gradient_checkpointing (line 679) | def test_train_with_peft_config_and_gradient_checkpointing(self):
method test_train_with_liger (line 721) | def test_train_with_liger(self):
method test_train_with_iterable_dataset (line 750) | def test_train_with_iterable_dataset(self):
method test_train_padding_free (line 781) | def test_train_padding_free(self):
method test_train_with_chat_template_kwargs (line 812) | def test_train_with_chat_template_kwargs(self):
method test_train_toolcall_data (line 866) | def test_train_toolcall_data(self):
method test_train_with_eval (line 894) | def test_train_with_eval(self):
method test_train_with_multiple_eval_dataset (line 913) | def test_train_with_multiple_eval_dataset(self):
method test_train_with_compute_metrics (line 932) | def test_train_with_compute_metrics(self):
method test_train_with_gradient_checkpointing (line 963) | def test_train_with_gradient_checkpointing(self):
method test_tag_added (line 992) | def test_tag_added(self):
method test_tag_added_peft (line 1006) | def test_tag_added_peft(self):
method test_train_vlm (line 1054) | def test_train_vlm(self, model_id):
method test_train_vlm_multi_image (line 1106) | def test_train_vlm_multi_image(self, model_id):
method test_train_vlm_gemma_3n (line 1143) | def test_train_vlm_gemma_3n(self):
method test_train_vlm_text_only_data (line 1186) | def test_train_vlm_text_only_data(self, model_id, dataset_config):
method test_train_vlm_with_max_length (line 1216) | def test_train_vlm_with_max_length(self):
method test_peft_with_quantization (line 1237) | def test_peft_with_quantization(self):
method test_train_vlm_keep_end_raises (line 1292) | def test_train_vlm_keep_end_raises(self):
FILE: tests/test_grpo_trainer.py
function multiply_tool (line 62) | def multiply_tool(a: int, b: int) -> int:
function async_multiply_tool (line 76) | async def async_multiply_tool(a: int, b: int) -> int:
class TestGetHighEntropyMask (line 90) | class TestGetHighEntropyMask(TrlTestCase):
method get_high_entropy_mask (line 91) | def get_high_entropy_mask(self, entropies, mask, threshold):
method test_compute_entropy_mask_0 (line 109) | def test_compute_entropy_mask_0(self):
method test_compute_entropy_mask_1 (line 121) | def test_compute_entropy_mask_1(self):
method test_compute_entropy_mask_lower_threshold (line 129) | def test_compute_entropy_mask_lower_threshold(self):
method test_compute_entropy_threshold_0 (line 137) | def test_compute_entropy_threshold_0(self):
method test_compute_entropy_threshold_1 (line 145) | def test_compute_entropy_threshold_1(self):
method test_compute_entropy_all_masked (line 153) | def test_compute_entropy_all_masked(self):
class TestGRPORolloutDispatch (line 162) | class TestGRPORolloutDispatch:
method _make_trainer (line 163) | def _make_trainer(self):
method test_generate_prefers_rollout_func (line 202) | def test_generate_prefers_rollout_func(self):
method test_generate_rollout_func_syncs_vllm_weights_when_needed (line 220) | def test_generate_rollout_func_syncs_vllm_weights_when_needed(self):
method test_generate_rollout_func_raises_when_required_keys_are_missing (line 233) | def test_generate_rollout_func_raises_when_required_keys_are_missing(s...
class TestGRPOTrainer (line 241) | class TestGRPOTrainer(TrlTestCase):
method test_init_minimal (line 242) | def test_init_minimal(self):
method test_training (line 252) | def test_training(self, config_name):
method test_training_loss_types (line 282) | def test_training_loss_types(self, loss_type):
method test_training_with_eval (line 314) | def test_training_with_eval(self):
method test_training_with_num_generations_eval (line 337) | def test_training_with_num_generations_eval(self):
method test_training_eval_on_start (line 364) | def test_training_eval_on_start(self):
method test_training_multiple_iterations (line 388) | def test_training_multiple_iterations(self):
method test_training_peft_config (line 419) | def test_training_peft_config(self):
method test_training_peft_model (line 455) | def test_training_peft_model(self):
method test_training_peft_with_gradient_checkpointing (line 495) | def test_training_peft_with_gradient_checkpointing(self):
method test_training_different_reward_model (line 531) | def test_training_different_reward_model(self):
method test_training_reward_func_standard (line 570) | def test_training_reward_func_standard(self):
method test_training_reward_func_conversational (line 604) | def test_training_reward_func_conversational(self):
method test_training_multiple_reward_funcs (line 639) | def test_training_multiple_reward_funcs(self):
method test_training_sync_and_async_reward_funcs (line 677) | def test_training_sync_and_async_reward_funcs(self):
method test_training_multiple_reward_funcs_with_None_output (line 718) | def test_training_multiple_reward_funcs_with_None_output(self):
method test_training_multiple_reward_funcs_with_weights (line 762) | def test_training_multiple_reward_funcs_with_weights(self):
method test_training_multiple_mixed_reward_funcs (line 806) | def test_training_multiple_mixed_reward_funcs(self):
method test_training_reward_func_additional_column (line 840) | def test_training_reward_func_additional_column(self):
method test_training_with_sync_ref_model (line 880) | def test_training_with_sync_ref_model(self):
method test_training_beta_non_zero (line 916) | def test_training_beta_non_zero(self):
method test_training_with_pad_to_multiple_of (line 945) | def test_training_with_pad_to_multiple_of(self):
method test_get_off_policy_mask (line 975) | def test_get_off_policy_mask(self):
method test_get_off_policy_mask_padding (line 1002) | def test_get_off_policy_mask_padding(self):
method test_training_with_off_policy_mask (line 1041) | def test_training_with_off_policy_mask(self):
method test_training_with_off_policy_mask_with_liger (line 1072) | def test_training_with_off_policy_mask_with_liger(self):
method test_compute_liger_loss_passes_vllm_is_ratio (line 1103) | def test_compute_liger_loss_passes_vllm_is_ratio(self):
method test_training_with_bias_correction_kl (line 1153) | def test_training_with_bias_correction_kl(self):
method test_training_with_cast_lm_head_to_fp32 (line 1188) | def test_training_with_cast_lm_head_to_fp32(self, model_name):
method test_training_with_entropy_filter (line 1217) | def test_training_with_entropy_filter(self):
method test_training_vllm_and_peft (line 1249) | def test_training_vllm_and_peft(self):
method test_training_vllm_structured_outputs (line 1296) | def test_training_vllm_structured_outputs(self):
method test_training_vllm_importance_sampling_correction (line 1330) | def test_training_vllm_importance_sampling_correction(self):
method test_training_with_additional_generation_kwargs (line 1363) | def test_training_with_additional_generation_kwargs(self):
method test_training_vllm_with_additional_generation_kwargs (line 1400) | def test_training_vllm_with_additional_generation_kwargs(self):
method test_training_normalize_then_sum_aggregation (line 1436) | def test_training_normalize_then_sum_aggregation(self):
method test_training_scale_rewards (line 1475) | def test_training_scale_rewards(self, scale_rewards):
method test_training_with_mask_truncated_completions (line 1506) | def test_training_with_mask_truncated_completions(self, mock_generate):
method test_training_with_mask_truncated_completions_all_masked (line 1555) | def test_training_with_mask_truncated_completions_all_masked(self):
method test_warning_raised_all_rewards_none (line 1593) | def test_warning_raised_all_rewards_none(self, caplog):
method test_training_num_generations_larger_than_batch_size (line 1622) | def test_training_num_generations_larger_than_batch_size(self):
method test_training_delta_clipping (line 1652) | def test_training_delta_clipping(self):
method test_training_multiple_dataloader_workers (line 1682) | def test_training_multiple_dataloader_workers(self):
method test_training_with_generation_kwargs (line 1723) | def test_training_with_generation_kwargs(self):
method test_training_with_reward_func_accessing_trainer_state (line 1754) | def test_training_with_reward_func_accessing_trainer_state(self):
method test_training_reward_func_with_log_extra (line 1779) | def test_training_reward_func_with_log_extra(self):
method test_training_reward_func_with_log_metric (line 1805) | def test_training_reward_func_with_log_metric(self):
method test_prepare_input_called_with_correct_data (line 1832) | def test_prepare_input_called_with_correct_data(self):
method test_training_vlm (line 1901) | def test_training_vlm(self, model_id):
method test_training_vlm_with_pad_to_multiple_of (line 1946) | def test_training_vlm_with_pad_to_multiple_of(self):
method test_training_vlm_beta_non_zero (line 1989) | def test_training_vlm_beta_non_zero(self, model_id):
method test_training_vlm_peft (line 2036) | def test_training_vlm_peft(self, model_id):
method test_training_vlm_and_importance_sampling (line 2082) | def test_training_vlm_and_importance_sampling(self, model_id):
method test_training_vlm_and_liger (line 2136) | def test_training_vlm_and_liger(self, model_id):
method test_training_vlm_and_vllm (line 2185) | def test_training_vlm_and_vllm(self, model_id) -> None:
method test_training_vlm_multi_image (line 2226) | def test_training_vlm_multi_image(self, model_id):
method test_training_sequence_importance_sampling (line 2261) | def test_training_sequence_importance_sampling(self):
method test_training_with_chat_template_kwargs (line 2292) | def test_training_with_chat_template_kwargs(self):
method test_training_with_tools (line 2329) | def test_training_with_tools(self, tools: list[Callable]):
method test_training_with_environment_factory (line 2418) | def test_training_with_environment_factory(self):
method test_training_with_malformed_tool_calls (line 2513) | def test_training_with_malformed_tool_calls(self):
method test_mismatched_reward_processing_classes_length (line 2562) | def test_mismatched_reward_processing_classes_length(self):
method test_correct_reward_processing_classes_list (line 2588) | def test_correct_reward_processing_classes_list(self):
method test_single_reward_model_with_single_processing_class (line 2619) | def test_single_reward_model_with_single_processing_class(self):
class TestGRPOTrainerSlow (line 2647) | class TestGRPOTrainerSlow(TrlTestCase):
method setup_method (line 2648) | def setup_method(self):
method teardown_method (line 2653) | def teardown_method(self):
method test_training_with_liger_grpo_kernel (line 2666) | def test_training_with_liger_grpo_kernel(self, model_name):
method test_training_with_liger_grpo_kernel_and_peft (line 2712) | def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
method test_liger_grpo_kernel_importance_sampling (line 2773) | def test_liger_grpo_kernel_importance_sampling(self):
method test_training_with_transformers_paged (line 2820) | def test_training_with_transformers_paged(self, model_name):
method test_vlm_training (line 2867) | def test_vlm_training(self, model_name):
method test_vlm_processor_vllm_colocate_mode (line 2994) | def test_vlm_processor_vllm_colocate_mode(self):
method test_training_vllm (line 3142) | def test_training_vllm(self):
FILE: tests/test_model_utils.py
class TestDisableGradientCheckpointing (line 20) | class TestDisableGradientCheckpointing:
method test_when_disabled (line 21) | def test_when_disabled(self):
method test_when_enabled (line 28) | def test_when_enabled(self):
FILE: tests/test_reward_trainer.py
class TestDataCollatorForPreference (line 34) | class TestDataCollatorForPreference(TrlTestCase):
method test_basic_padding (line 35) | def test_basic_padding(self):
method test_pad_to_multiple_of (line 50) | def test_pad_to_multiple_of(self):
method test_single_example (line 67) | def test_single_example(self):
method test_different_pad_token_id (line 77) | def test_different_pad_token_id(self):
method test_collate_with_margin (line 94) | def test_collate_with_margin(self):
class TestRewardTrainer (line 110) | class TestRewardTrainer(TrlTestCase):
method test_raises_error_when_model_num_labels_not_one (line 111) | def test_raises_error_when_model_num_labels_not_one(self):
method test_train (line 135) | def test_train(self, model_id):
method test_train_dataset_types (line 166) | def test_train_dataset_types(self, config_name):
method test_train_model (line 192) | def test_train_model(self):
method test_train_from_sequence_classification_model (line 221) | def test_train_from_sequence_classification_model(self):
method test_train_model_dtype (line 247) | def test_train_model_dtype(self):
method test_train_dense_with_peft_config (line 285) | def test_train_dense_with_peft_config(self):
method test_train_moe_with_peft_config (line 322) | def test_train_moe_with_peft_config(self):
method test_train_peft_model (line 359) | def test_train_peft_model(self):
method test_train_with_peft_config_and_gradient_checkpointing (line 403) | def test_train_with_peft_config_and_gradient_checkpointing(self):
method test_train_with_peft_config_and_gradient_checkpointing_reentrant (line 441) | def test_train_with_peft_config_and_gradient_checkpointing_reentrant(s...
method test_train_with_pretokenized_data (line 489) | def test_train_with_pretokenized_data(self, chosen_column, rejected_co...
method test_train_with_iterable_dataset (line 531) | def test_train_with_iterable_dataset(self):
method test_train_with_chat_template_kwargs (line 559) | def test_train_with_chat_template_kwargs(self):
method test_train_with_set_chat_template_from_model (line 610) | def test_train_with_set_chat_template_from_model(self):
method test_train_with_set_chat_template_from_path (line 642) | def test_train_with_set_chat_template_from_path(self, lazy_shared_data...
method test_train_toolcall_data (line 688) | def test_train_toolcall_data(self):
method test_train_toolcall_data_as_json (line 714) | def test_train_toolcall_data_as_json(self):
method test_train_with_eval (line 748) | def test_train_with_eval(self):
method test_train_with_multiple_eval_dataset (line 767) | def test_train_with_multiple_eval_dataset(self):
method test_train_with_compute_metrics (line 786) | def test_train_with_compute_metrics(self):
method test_train_with_gradient_checkpointing (line 817) | def test_train_with_gradient_checkpointing(self):
method test_train_with_gradient_checkpointing_reentrant (line 844) | def test_train_with_gradient_checkpointing_reentrant(self, use_reentra...
method test_tag_added (line 875) | def test_tag_added(self):
method test_tag_added_peft (line 889) | def test_tag_added_peft(self):
method test_train_with_margin (line 903) | def test_train_with_margin(self):
method test_train_with_center_rewards_coefficient (line 935) | def test_train_with_center_rewards_coefficient(self):
FILE: tests/test_rewards.py
class TestThinkFormatReward (line 22) | class TestThinkFormatReward(TrlTestCase):
method test_valid_format (line 23) | def test_valid_format(self):
method test_invalid_format (line 36) | def test_invalid_format(self):
method test_mixed_format (line 53) | def test_mixed_format(self):
class TestSoftOverlongPunishmentReward (line 66) | class TestSoftOverlongPunishmentReward:
method test_soft_overlong_punishment_short_completion (line 67) | def test_soft_overlong_punishment_short_completion(self):
method test_soft_overlong_punishment_long_completion (line 75) | def test_soft_overlong_punishment_long_completion(self):
method test_soft_overlong_punishment_intermediate_completion (line 83) | def test_soft_overlong_punishment_intermediate_completion(self):
class TestAccuracyReward (line 91) | class TestAccuracyReward:
method test_accuracy_reward_correct_answer (line 93) | def test_accuracy_reward_correct_answer(self):
method test_accuracy_reward_wrong_answer (line 102) | def test_accuracy_reward_wrong_answer(self):
method test_accuracy_reward_wrong_answer_no_latex (line 110) | def test_accuracy_reward_wrong_answer_no_latex(self):
method test_accuracy_reward_unparsable_gold (line 118) | def test_accuracy_reward_unparsable_gold(self):
method test_accuracy_reward_in_worker_thread (line 133) | def test_accuracy_reward_in_worker_thread(self):
class TestReasoningAccuracyReward (line 154) | class TestReasoningAccuracyReward:
method test_correct_answer_yields_unit_reward (line 156) | def test_correct_answer_yields_unit_reward(self):
method test_correct_answer_with_custom_tags_yields_unit_reward (line 167) | def test_correct_answer_with_custom_tags_yields_unit_reward(self):
method test_incorrect_answer_yields_zero_reward (line 178) | def test_incorrect_answer_yields_zero_reward(self):
method test_correct_answer_in_reasoning_yields_zero_reward (line 185) | def test_correct_answer_in_reasoning_yields_zero_reward(self):
method test_incomplete_reasoning_yields_zero_reward (line 196) | def test_incomplete_reasoning_yields_zero_reward(self):
method test_unparsable_gold_solution_yields_none_reward (line 207) | def test_unparsable_gold_solution_yields_none_reward(self):
FILE: tests/test_rich_progress_callback.py
class DummyModel (line 25) | class DummyModel(nn.Module):
method __init__ (line 26) | def __init__(self):
method forward (line 30) | def forward(self, x):
class TestRichProgressCallback (line 35) | class TestRichProgressCallback(TrlTestCase):
method setup_method (line 36) | def setup_method(self):
method test_rich_progress_callback_logging (line 41) | def test_rich_progress_callback_logging(self):
FILE: tests/test_rloo_trainer.py
class TestRLOOTrainer (line 39) | class TestRLOOTrainer(TrlTestCase):
method test_init_minimal (line 40) | def test_init_minimal(self):
method test_training (line 50) | def test_training(self, config_name):
method test_training_with_eval (line 79) | def test_training_with_eval(self):
method test_training_with_num_generations_eval (line 102) | def test_training_with_num_generations_eval(self):
method test_training_multiple_iterations (line 126) | def test_training_multiple_iterations(self):
method test_training_peft_config (line 157) | def test_training_peft_config(self):
method test_training_peft_model (line 193) | def test_training_peft_model(self):
method test_training_peft_with_gradient_checkpointing (line 233) | def test_training_peft_with_gradient_checkpointing(self):
method test_training_different_reward_model (line 269) | def test_training_different_reward_model(self):
method test_training_reward_func_standard (line 308) | def test_training_reward_func_standard(self):
method test_training_reward_func_conversational (line 342) | def test_training_reward_func_conversational(self):
method test_training_multiple_reward_funcs (line 377) | def test_training_multiple_reward_funcs(self):
method test_training_sync_and_async_reward_funcs (line 415) | def test_training_sync_and_async_reward_funcs(self):
method test_training_multiple_reward_funcs_with_None_output (line 456) | def test_training_multiple_reward_funcs_with_None_output(self):
method test_training_multiple_reward_funcs_with_weights (line 500) | def test_training_multiple_reward_funcs_with_weights(self):
method test_training_multiple_mixed_reward_funcs (line 544) | def test_training_multiple_mixed_reward_funcs(self):
method test_training_reward_func_additional_column (line 578) | def test_training_reward_func_additional_column(self):
method test_training_with_sync_ref_model (line 618) | def test_training_with_sync_ref_model(self):
method test_training_beta_zero (line 654) | def test_training_beta_zero(self):
method test_training_with_pad_to_multiple_of (line 683) | def test_training_with_pad_to_multiple_of(self):
method test_training_vllm_and_peft (line 716) | def test_training_vllm_and_peft(self):
method test_training_vllm_structured_outputs (line 763) | def test_training_vllm_structured_outputs(self):
method test_training_with_additional_generation_kwargs (line 795) | def test_training_with_additional_generation_kwargs(self):
method test_training_vllm_with_additional_generation_kwargs (line 832) | def test_training_vllm_with_additional_generation_kwargs(self):
method test_training_with_normalized_advantages (line 868) | def test_training_with_normalized_advantages(self):
method test_training_with_clipped_rewards (line 898) | def test_training_with_clipped_rewards(self):
method test_training_with_mask_truncated_completions (line 929) | def test_training_with_mask_truncated_completions(self, mock_generate):
method test_training_with_mask_truncated_completions_all_masked (line 978) | def test_training_with_mask_truncated_completions_all_masked(self):
method test_warning_raised_all_rewards_none (line 1016) | def test_warning_raised_all_rewards_none(self, caplog):
method test_training_num_generations_larger_than_batch_size (line 1045) | def test_training_num_generations_larger_than_batch_size(self):
method test_training_multiple_dataloader_workers (line 1075) | def test_training_multiple_dataloader_workers(self):
method test_training_with_generation_kwargs (line 1116) | def test_training_with_generation_kwargs(self):
method test_training_with_reward_func_accessing_trainer_state (line 1147) | def test_training_with_reward_func_accessing_trainer_state(self):
method test_training_reward_func_with_log_extra (line 1172) | def test_training_reward_func_with_log_extra(self):
method test_training_reward_func_with_log_metric (line 1198) | def test_training_reward_func_with_log_metric(self):
method test_prepare_input_called_with_correct_data (line 1225) | def test_prepare_input_called_with_correct_data(self):
method test_training_vlm (line 1294) | def test_training_vlm(self, model_id):
method test_training_vlm_with_pad_to_multiple_of (line 1338) | def test_training_vlm_with_pad_to_multiple_of(self):
method test_training_vlm_beta_non_zero (line 1381) | def test_training_vlm_beta_non_zero(self, model_id):
method test_training_vlm_peft (line 1428) | def test_training_vlm_peft(self, model_id):
method test_training_vlm_and_vllm (line 1477) | def test_training_vlm_and_vllm(self, model_id) -> None:
method test_training_vlm_multi_image (line 1518) | def test_training_vlm_multi_image(self, model_id):
method test_training_with_chat_template_kwargs (line 1550) | def test_training_with_chat_template_kwargs(self):
method test_mismatched_reward_processing_classes_length (line 1581) | def test_mismatched_reward_processing_classes_length(self):
method test_correct_reward_processing_classes_list (line 1607) | def test_correct_reward_processing_classes_list(self):
method test_single_reward_model_with_single_processing_class (line 1638) | def test_single_reward_model_with_single_processing_class(self):
FILE: tests/test_sft_trainer.py
class TestDFTLoss (line 61) | class TestDFTLoss(TrlTestCase):
method test_dft_loss (line 62) | def test_dft_loss(self):
class TestDataCollatorForLanguageModeling (line 84) | class TestDataCollatorForLanguageModeling(TrlTestCase):
method test_basic_padding (line 85) | def test_basic_padding(self):
method test_completion_mask (line 97) | def test_completion_mask(self):
method test_completion_only_loss_disabled (line 112) | def test_completion_only_loss_disabled(self):
method test_padding_free_mode (line 128) | def test_padding_free_mode(self):
method test_padding_free_with_completion_mask (line 140) | def test_padding_free_with_completion_mask(self):
method test_packing (line 155) | def test_packing(self):
method test_pad_to_multiple_of (line 172) | def test_pad_to_multiple_of(self):
method test_pad_to_multiple_of_and_padding_free (line 184) | def test_pad_to_multiple_of_and_padding_free(self):
method test_custom_position_ids_but_no_padding_free (line 196) | def test_custom_position_ids_but_no_padding_free(self):
method test_single_example (line 208) | def test_single_example(self):
method test_different_pad_token_id (line 220) | def test_different_pad_token_id(self):
method test_assistant_masks (line 232) | def test_assistant_masks(self):
method test_single_example_single_doc (line 246) | def test_single_example_single_doc(self):
method test_single_example_multiple_docs (line 252) | def test_single_example_multiple_docs(self):
method test_multiple_examples (line 259) | def test_multiple_examples(self):
class TestSFTTrainer (line 267) | class TestSFTTrainer(TrlTestCase):
method test_init_with_training_arguments (line 268) | def test_init_with_training_arguments(self):
method test_train (line 289) | def test_train(self, model_id):
method test_train_gpt_oss (line 312) | def test_train_gpt_oss(self):
method test_train_model (line 336) | def test_train_model(self):
method test_train_dft_loss (line 364) | def test_train_dft_loss(self):
method test_train_moe_model_with_aux_loss (line 398) | def test_train_moe_model_with_aux_loss(self):
method test_train_with_formatting_func (line 426) | def test_train_with_formatting_func(self):
method test_train_model_dtype (line 458) | def test_train_model_dtype(self):
method test_train_dense_with_peft_config_lora (line 494) | def test_train_dense_with_peft_config_lora(self):
method test_train_with_peft_config_prompt_tuning (line 539) | def test_train_with_peft_config_prompt_tuning(self, peft_type):
method test_train_moe_with_peft_config (line 597) | def test_train_moe_with_peft_config(self):
method test_train_peft_model (line 634) | def test_train_peft_model(self):
method test_train_with_peft_config_and_gradient_checkpointing (line 674) | def test_train_with_peft_config_and_gradient_checkpointing(self):
method test_train_with_peft_config_and_gradient_checkpointing_reentrant (line 712) | def test_train_with_peft_config_and_gradient_checkpointing_reentrant(s...
method test_train_with_liger (line 754) | def test_train_with_liger(self):
method test_compute_loss_skip_logits_on_eval_without_metrics_with_liger (line 780) | def test_compute_loss_skip_logits_on_eval_without_metrics_with_liger(s...
method test_predict_does_not_skip_logits_with_liger (line 822) | def test_predict_does_not_skip_logits_with_liger(self):
method test_train_with_non_chatml_conversational_data (line 854) | def test_train_with_non_chatml_conversational_data(self):
method test_train_with_pretokenized_data (line 884) | def test_train_with_pretokenized_data(self):
method test_train_with_iterable_dataset (line 914) | def test_train_with_iterable_dataset(self):
method test_train_padding_free (line 940) | def test_train_padding_free(self):
method test_train_packing (line 973) | def test_train_packing(self, packing_strategy):
method test_eval_packing (line 1001) | def test_eval_packing(self):
method test_only_train_packing (line 1035) | def test_only_train_packing(self):
method test_train_with_chat_template_kwargs (line 1068) | def test_train_with_chat_template_kwargs(self):
method test_train_assistant_only (line 1118) | def test_train_assistant_only(self):
method test_train_completion_only (line 1142) | def test_train_completion_only(self):
method test_train_completion_only_harmony (line 1166) | def test_train_completion_only_harmony(self):
method test_train_assistant_only_and_completion_only (line 1190) | def test_train_assistant_only_and_completion_only(self):
method test_train_assistant_only_iterable_dataset (line 1224) | def test_train_assistant_only_iterable_dataset(self):
method test_train_with_set_chat_template_from_model (line 1250) | def test_train_with_set_chat_template_from_model(self):
method test_train_with_set_chat_template_from_path (line 1275) | def test_train_with_set_chat_template_from_path(self, lazy_shared_data...
method test_train_toolcall_data (line 1314) | def test_train_toolcall_data(self):
method test_train_toolcall_data_as_json (line 1338) | def test_train_toolcall_data_as_json(self):
method test_train_with_eval (line 1371) | def test_train_with_eval(self):
method test_train_with_multiple_eval_dataset (line 1390) | def test_train_with_multiple_eval_dataset(self):
method test_train_with_compute_metrics (line 1409) | def test_train_with_compute_metrics(self):
method test_train_with_gradient_checkpointing (line 1440) | def test_train_with_gradient_checkpointing(self):
method test_train_with_gradient_checkpointing_reentrant (line 1465) | def test_train_with_gradient_checkpointing_reentrant(self, use_reentra...
method test_tag_added (line 1494) | def test_tag_added(self):
method test_tag_added_peft (line 1508) | def test_tag_added_peft(self):
method test_train_vlm (line 1556) | def test_train_vlm(self, model_id):
method test_train_vlm_multi_image (line 1607) | def test_train_vlm_multi_image(self, model_id):
method test_train_vlm_prompt_completion (line 1649) | def test_train_vlm_prompt_completion(self, model_id):
method test_train_vlm_gemma_3n (line 1685) | def test_train_vlm_gemma_3n(self):
method test_train_vlm_text_only_data (line 1728) | def test_train_vlm_text_only_data(self, model_id, dataset_config):
method test_prompt_tuning (line 1758) | def test_prompt_tuning(self):
method test_peft_with_quantization (line 1791) | def test_peft_with_quantization(self):
method test_prompt_tuning_peft_model (line 1846) | def test_prompt_tuning_peft_model(self):
class TestSFTTrainerSlow (line 1879) | class TestSFTTrainerSlow(TrlTestCase):
method setup_method (line 1880) | def setup_method(self):
method teardown_method (line 1892) | def teardown_method(self):
method test_sft_trainer_transformers_mp (line 1905) | def test_sft_trainer_transformers_mp(self, model_name, packing):
method test_sft_trainer_transformers_mp_gc_device_map (line 1949) | def test_sft_trainer_transformers_mp_gc_device_map(
method test_sft_trainer_transformers_mp_gc_peft_qlora (line 1997) | def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, p...
method test_sft_trainer_with_chat_format_qlora (line 2046) | def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
method test_sft_trainer_with_liger (line 2093) | def test_sft_trainer_with_liger(self, model_name, packing):
method test_train_offloading (line 2144) | def test_train_offloading(self, model_name, packing):
FILE: tests/test_skills.py
class TestGetTrlSkillsDir (line 23) | class TestGetTrlSkillsDir:
method test_returns_path_object (line 26) | def test_returns_path_object(self):
method test_directory_exists (line 31) | def test_directory_exists(self):
method test_is_directory (line 36) | def test_is_directory(self):
method test_contains_skills_module (line 41) | def test_contains_skills_module(self):
class TestListSkills (line 47) | class TestListSkills:
method test_returns_list (line 50) | def test_returns_list(self):
method test_contains_trl_training (line 55) | def test_contains_trl_training(self):
method test_skills_are_sorted (line 60) | def test_skills_are_sorted(self):
method test_with_custom_directory (line 65) | def test_with_custom_directory(self, tmp_path):
method test_empty_directory (line 77) | def test_empty_directory(self, tmp_path):
method test_nonexistent_directory (line 82) | def test_nonexistent_directory(self, tmp_path):
method test_ignores_files (line 88) | def test_ignores_files(self, tmp_path):
method test_requires_skill_md (line 97) | def test_requires_skill_md(self, tmp_path):
class TestInstallSkill (line 108) | class TestInstallSkill:
method test_basic_installation (line 111) | def test_basic_installation(self, tmp_path):
method test_creates_target_directory (line 121) | def test_creates_target_directory(self, tmp_path):
method test_skill_not_found (line 130) | def test_skill_not_found(self, tmp_path):
method test_skill_already_exists_without_force (line 137) | def test_skill_already_exists_without_force(self, tmp_path):
method test_force_overwrites_existing (line 148) | def test_force_overwrites_existing(self, tmp_path):
method test_force_overwrites_symlink (line 166) | def test_force_overwrites_symlink(self, tmp_path):
method test_skill_not_directory (line 182) | def test_skill_not_directory(self, tmp_path):
method test_preserves_directory_structure (line 194) | def test_preserves_directory_structure(self, tmp_path):
method test_install_to_same_directory_fails (line 212) | def test_install_to_same_directory_fails(self, tmp_path):
class TestUninstallSkill (line 227) | class TestUninstallSkill:
method test_basic_uninstallation (line 230) | def test_basic_uninstallation(self, tmp_path):
method test_skill_not_installed (line 244) | def test_skill_not_installed(self, tmp_path):
method test_uninstall_from_nonexistent_directory (line 252) | def test_uninstall_from_nonexistent_directory(self, tmp_path):
method test_uninstall_removes_all_contents (line 259) | def test_uninstall_removes_all_contents(self, tmp_path):
method test_uninstall_doesnt_affect_other_skills (line 280) | def test_uninstall_doesnt_affect_other_skills(self, tmp_path):
class TestIntegration (line 303) | class TestIntegration:
method test_full_workflow (line 306) | def test_full_workflow(self, tmp_path):
method test_install_uninstall_cycle (line 336) | def test_install_uninstall_cycle(self, tmp_path):
method test_force_reinstall_workflow (line 354) | def test_force_reinstall_workflow(self, tmp_path):
class TestEdgeCases (line 376) | class TestEdgeCases:
method test_skill_with_special_characters_in_name (line 379) | def test_skill_with_special_characters_in_name(self, tmp_path):
method test_empty_skill_directory (line 397) | def test_empty_skill_directory(self, tmp_path):
method test_skill_with_hidden_files (line 414) | def test_skill_with_hidden_files(self, tmp_path):
method test_list_skills_with_symlinks (line 429) | def test_list_skills_with_symlinks(self, tmp_path):
class TestListAgentNames (line 448) | class TestListAgentNames:
method test_returns_list (line 451) | def test_returns_list(self):
method test_contains_expected_agents (line 456) | def test_contains_expected_agents(self):
method test_agents_are_sorted (line 463) | def test_agents_are_sorted(self):
class TestResolveTargetPath (line 469) | class TestResolveTargetPath:
method test_resolve_agent_name_project_scope (line 472) | def test_resolve_agent_name_project_scope(self):
method test_resolve_agent_name_global_scope (line 477) | def test_resolve_agent_name_global_scope(self):
method test_resolve_custom_path_string (line 482) | def test_resolve_custom_path_string(self):
method test_resolve_custom_path_object (line 487) | def test_resolve_custom_path_object(self):
method test_resolve_path_with_tilde (line 493) | def test_resolve_path_with_tilde(self):
method test_all_predefined_agents (line 499) | def test_all_predefined_agents(self):
method test_invalid_scope_for_predefined_agent (line 507) | def test_invalid_scope_for_predefined_agent(self):
class TestHighLevelAPI (line 513) | class TestHighLevelAPI:
method test_list_skills_with_target_string (line 516) | def test_list_skills_with_target_string(self, tmp_path):
method test_list_skills_with_target_path (line 525) | def test_list_skills_with_target_path(self, tmp_path):
method test_list_skills_without_target (line 533) | def test_list_skills_without_target(self):
method test_install_skill_with_target_string (line 539) | def test_install_skill_with_target_string(self, tmp_path):
method test_install_skill_with_target_path (line 545) | def test_install_skill_with_target_path(self, tmp_path):
method test_install_skill_with_force (line 551) | def test_install_skill_with_force(self, tmp_path):
method test_uninstall_skill_with_target_string (line 558) | def test_uninstall_skill_with_target_string(self, tmp_path):
method test_uninstall_skill_with_target_path (line 565) | def test_uninstall_skill_with_target_path(self, tmp_path):
method test_install_with_custom_source (line 572) | def test_install_with_custom_source(self, tmp_path):
FILE: tests/test_skills_cli.py
class TestCLICommands (line 23) | class TestCLICommands:
method test_cmd_list_without_target (line 26) | def test_cmd_list_without_target(self, capsys):
method test_cmd_list_with_target (line 38) | def test_cmd_list_with_target(self, tmp_path, capsys):
method test_cmd_list_empty_target (line 51) | def test_cmd_list_empty_target(self, tmp_path, capsys):
method test_cmd_install_single_skill (line 61) | def test_cmd_install_single_skill(self, tmp_path, capsys):
method test_cmd_install_all_skills (line 73) | def test_cmd_install_all_skills(self, tmp_path, capsys):
method test_cmd_install_no_skill_or_all (line 85) | def test_cmd_install_no_skill_or_all(self, capsys):
method test_cmd_install_both_skill_and_all (line 95) | def test_cmd_install_both_skill_and_all(self, capsys):
method test_cmd_install_nonexistent_skill (line 105) | def test_cmd_install_nonexistent_skill(self, tmp_path, capsys):
method test_cmd_install_already_exists (line 116) | def test_cmd_install_already_exists(self, tmp_path, capsys):
method test_cmd_install_with_force (line 130) | def test_cmd_install_with_force(self, tmp_path, capsys):
method test_cmd_uninstall_success (line 144) | def test_cmd_uninstall_success(self, tmp_path, capsys):
method test_cmd_uninstall_not_installed (line 159) | def test_cmd_uninstall_not_installed(self, tmp_path, capsys):
method test_cmd_install_creates_target_directory (line 170) | def test_cmd_install_creates_target_directory(self, tmp_path, capsys):
method test_cmd_uninstall_invalid_target (line 187) | def test_cmd_uninstall_invalid_target(self, capsys):
class TestCLIArgumentParsing (line 198) | class TestCLIArgumentParsing:
method test_add_skills_subcommands_creates_parsers (line 201) | def test_add_skills_subcommands_creates_parsers(self):
method test_list_command_optional_target (line 222) | def test_list_command_optional_target(self):
method test_install_command_requires_target (line 236) | def test_install_command_requires_target(self):
method test_scope_choices (line 246) | def test_scope_choices(self):
method test_install_all_flag (line 263) | def test_install_all_flag(self):
method test_install_force_flag (line 273) | def test_install_force_flag(self):
method test_default_scope_is_project (line 282) | def test_default_scope_is_project(self):
FILE: tests/test_utils.py
class TestUseAdapter (line 55) | class TestUseAdapter(TrlTestCase):
method test_disables_on_none (line 56) | def test_disables_on_none(self):
method test_restores_previous_adapter (line 69) | def test_restores_previous_adapter(self):
method test_with_multiple_adapters (line 85) | def test_with_multiple_adapters(self):
class TestPad (line 107) | class TestPad(TrlTestCase):
method test_pad_1_dim_left (line 108) | def test_pad_1_dim_left(self):
method test_pad_1_dim_right (line 115) | def test_pad_1_dim_right(self):
method test_pad_2_dim_left (line 122) | def test_pad_2_dim_left(self):
method test_pad_2_dim_right (line 134) | def test_pad_2_dim_right(self):
method test_pad_2_dim_right_multidim (line 146) | def test_pad_2_dim_right_multidim(self):
method test_pad_to_multiple_of_1 (line 158) | def test_pad_to_multiple_of_1(self):
method test_pad_to_multiple_of_2 (line 166) | def test_pad_to_multiple_of_2(self):
method test_pad_to_multiple_of_side_left (line 174) | def test_pad_to_multiple_of_side_left(self):
method test_pad_to_multiple_of_no_extra_padding (line 182) | def test_pad_to_multiple_of_no_extra_padding(self):
class TestHashModule (line 191) | class TestHashModule(TrlTestCase):
method test_hash_module_deterministic_across_order (line 192) | def test_hash_module_deterministic_across_order(self):
method test_hash_module_changes_with_value (line 209) | def test_hash_module_changes_with_value(self):
method test_hash_module_includes_dtype (line 217) | def test_hash_module_includes_dtype(self):
method test_hash_module_tiny_model_twice (line 225) | def test_hash_module_tiny_model_twice(self):
method test_hash_module_tiny_model_change_layer (line 231) | def test_hash_module_tiny_model_change_layer(self):
class TestGetPEFTConfig (line 242) | class TestGetPEFTConfig(TrlTestCase):
method test_create_peft_config_use_peft_false (line 243) | def test_create_peft_config_use_peft_false(self):
method test_create_peft_config_use_peft_true (line 249) | def test_create_peft_config_use_peft_true(self):
class TestNanStd (line 275) | class TestNanStd(TrlTestCase):
method test_nanstd_ignores_nans (line 276) | def test_nanstd_ignores_nans(self):
method test_nanstd_dim_and_keepdim (line 281) | def test_nanstd_dim_and_keepdim(self):
method test_nanstd_all_nan (line 287) | def test_nanstd_all_nan(self):
class TestGenerateModelCard (line 293) | class TestGenerateModelCard(TrlTestCase):
method test_full (line 294) | def test_full(self):
method test_val_none (line 321) | def test_val_none(self):
class TestFlushLeft (line 342) | class TestFlushLeft(TrlTestCase):
method test_basic_case (line 343) | def test_basic_case(self):
method test_single_row (line 357) | def test_single_row(self):
method test_no_shift_needed (line 368) | def test_no_shift_needed(self):
method test_no_tensors (line 379) | def test_no_tensors(self):
class TestFlushRight (line 386) | class TestFlushRight(TrlTestCase):
method test_basic_case (line 387) | def test_basic_case(self):
method test_single_row (line 401) | def test_single_row(self):
method test_no_shift_needed (line 412) | def test_no_shift_needed(self):
method test_no_tensors (line 423) | def test_no_tensors(self):
class TestRepeatRandomSampler (line 430) | class TestRepeatRandomSampler(TrlTestCase):
method test_sampler (line 431) | def test_sampler(self):
method test_sampler_no_shuffle (line 443) | def test_sampler_no_shuffle(self):
method test_sampler_no_repeat (line 450) | def test_sampler_no_repeat(self):
method test_sampler_with_batch_size (line 460) | def test_sampler_with_batch_size(self):
method test_sampler_with_batch_size_and_drop (line 472) | def test_sampler_with_batch_size_and_drop(self):
method test_sampler_with_mini_repeat_count_and_batch_size_1 (line 487) | def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
method test_sampler_with_mini_repeat_count_and_batch_size_2 (line 504) | def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
method test_sampler_with_mini_repeat_count_and_batch_size_3 (line 523) | def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
class TestEntropyFromLogits (line 542) | class TestEntropyFromLogits(TrlTestCase):
method test_entropy_from_logits_2_dims (line 546) | def test_entropy_from_logits_2_dims(self, dtype, chunk_size, shape):
class TestPrintPromptCompletionsSample (line 559) | class TestPrintPromptCompletionsSample(TrlTestCase):
method test_print_output (line 561) | def test_print_output(self, mock_stdout):
method test_num_samples (line 588) | def test_num_samples(self, mock_stdout):
method test_print_messages (line 623) | def test_print_messages(self, mock_stdout):
method test_print_messages_with_tools (line 672) | def test_print_messages_with_tools(self, mock_stdout):
class TestSelectiveLogSoftmax (line 711) | class TestSelectiveLogSoftmax(TrlTestCase):
method test_selective_log_softmax (line 713) | def test_selective_log_softmax(self, dtype):
method test_selective_log_softmax_multi_index (line 733) | def test_selective_log_softmax_multi_index(self, dtype, k):
class TestShuffleSequenceDict (line 753) | class TestShuffleSequenceDict(TrlTestCase):
method test_shuffle_preserves_shape (line 754) | def test_shuffle_preserves_shape(self):
method test_shuffle_consistent_across_tensors (line 764) | def test_shuffle_consistent_across_tensors(self):
method test_none_tensor_remains_none (line 786) | def test_none_tensor_remains_none(self):
method test_shuffle_with_list (line 795) | def test_shuffle_with_list(self):
class TestSplitTensorDict (line 818) | class TestSplitTensorDict(TrlTestCase):
method test_split_equal_chunks (line 819) | def test_split_equal_chunks(self):
method test_with_none_tensor (line 833) | def test_with_none_tensor(self):
method test_with_scalar (line 845) | def test_with_scalar(self):
class TestSplitPixelValuesByGrid (line 858) | class TestSplitPixelValuesByGrid(TrlTestCase):
method test_split_correctly_0 (line 859) | def test_split_correctly_0(self):
method test_split_correctly_1 (line 875) | def test_split_correctly_1(self):
method test_missing_keys (line 891) | def test_missing_keys(self):
method test_mismatched_length (line 896) | def test_mismatched_length(self):
method test_multi_images (line 905) | def test_multi_images(self):
class TestUnsplitPixelValuesByGrid (line 922) | class TestUnsplitPixelValuesByGrid(TrlTestCase):
method test_unsplit_correctly (line 923) | def test_unsplit_correctly(self):
method test_no_op_if_not_list (line 936) | def test_no_op_if_not_list(self):
class TestForwardMaskedLogits (line 943) | class TestForwardMaskedLogits:
method test_llm (line 965) | def test_llm(self, model_id):
method test_vlm (line 1015) | def test_vlm(self, model_id):
FILE: tests/test_vllm_client_server.py
class TestChunkList (line 48) | class TestChunkList(TrlTestCase):
method test_even_split (line 49) | def test_even_split(self):
method test_uneven_split (line 52) | def test_uneven_split(self):
method test_more_chunks_than_elements (line 55) | def test_more_chunks_than_elements(self):
method test_n_equals_len (line 58) | def test_n_equals_len(self):
method test_n_is_1 (line 61) | def test_n_is_1(self):
method test_single_element_list (line 64) | def test_single_element_list(self):
method test_any_dtype (line 67) | def test_any_dtype(self):
class TestExtractLogprobs (line 74) | class TestExtractLogprobs(TrlTestCase):
method test_extract_logprobs_sorts_by_rank_and_replaces_nan (line 75) | def test_extract_logprobs_sorts_by_rank_and_replaces_nan(self):
method test_extract_logprobs_returns_none_token_ids_when_logprobs_missing (line 118) | def test_extract_logprobs_returns_none_token_ids_when_logprobs_missing...
class TestVLLMClientServer (line 130) | class TestVLLMClientServer(TrlTestCase):
method setup_class (line 134) | def setup_class(cls):
method test_generate (line 149) | def test_generate(self):
method test_generate_with_logprobs_none (line 169) | def test_generate_with_logprobs_none(self):
method test_chat (line 177) | def test_chat(self):
method test_chat_with_logprobs_none (line 197) | def test_chat_with_logprobs_none(self):
method test_chat_with_tools (line 205) | def test_chat_with_tools(self):
method test_generate_with_token_ids (line 227) | def test_generate_with_token_ids(self):
method test_generate_with_params (line 252) | def test_generate_with_params(self):
method test_update_model_params (line 272) | def test_update_model_params(self):
method test_reset_prefix_cache (line 276) | def test_reset_prefix_cache(self):
method test_logprobs_match_with_non_default_sampling (line 281) | def test_logprobs_match_with_non_default_sampling(self):
method teardown_class (line 362) | def teardown_class(cls):
class TestVLLMClientServerBaseURL (line 375) | class TestVLLMClientServerBaseURL(TrlTestCase):
method setup_class (line 379) | def setup_class(cls):
method test_generate (line 394) | def test_generate(self):
method test_generate_with_logprobs_none (line 414) | def test_generate_with_logprobs_none(self):
method test_chat (line 422) | def test_chat(self):
method test_chat_with_logprobs_none (line 442) | def test_chat_with_logprobs_none(self):
method test_chat_with_tools (line 450) | def test_chat_with_tools(self):
method test_generate_with_token_ids (line 472) | def test_generate_with_token_ids(self):
method test_generate_with_params (line 497) | def test_generate_with_params(self):
method test_update_model_params (line 517) | def test_update_model_params(self):
method test_reset_prefix_cache (line 521) | def test_reset_prefix_cache(self):
method teardown_class (line 526) | def teardown_class(cls):
class TestVLLMClientServerTP (line 538) | class TestVLLMClientServerTP(TrlTestCase):
method setup_class (line 542) | def setup_class(cls):
method test_generate (line 560) | def test_generate(self):
method test_generate_with_logprobs_none (line 580) | def test_generate_with_logprobs_none(self):
method test_chat (line 588) | def test_chat(self):
method test_chat_with_logprobs_none (line 608) | def test_chat_with_logprobs_none(self):
method test_chat_with_tools (line 616) | def test_chat_with_tools(self):
method test_generate_with_token_ids (line 638) | def test_generate_with_token_ids(self):
method test_generate_with_params (line 663) | def test_generate_with_params(self):
method test_update_model_params (line 683) | def test_update_model_params(self):
method test_reset_prefix_cache (line 687) | def test_reset_prefix_cache(self):
method teardown_class (line 692) | def teardown_class(cls):
class TestVLLMClientServerDP (line 708) | class TestVLLMClientServerDP(TrlTestCase):
method setup_class (line 712) | def setup_class(cls):
method test_generate (line 730) | def test_generate(self):
method test_generate_with_logprobs_none (line 750) | def test_generate_with_logprobs_none(self):
method test_chat (line 758) | def test_chat(self):
method test_chat_with_logprobs_none (line 778) | def test_chat_with_logprobs_none(self):
method test_chat_with_tools (line 786) | def test_chat_with_tools(self):
method test_generate_with_token_ids (line 808) | def test_generate_with_token_ids(self):
method test_generate_with_params (line 833) | def test_generate_with_params(self):
method test_update_model_params (line 853) | def test_update_model_params(self):
method test_reset_prefix_cache (line 857) | def test_reset_prefix_cache(self):
method teardown_class (line 862) | def teardown_class(cls):
class TestVLLMClientServerDeviceParameter (line 874) | class TestVLLMClientServerDeviceParameter(TrlTestCase):
method setup_class (line 880) | def setup_class(cls):
method test_init_communicator_with_device_int (line 891) | def test_init_communicator_with_device_int(self):
method test_init_communicator_with_device_string (line 908) | def test_init_communicator_with_device_string(self):
method test_init_communicator_with_torch_device (line 921) | def test_init_communicator_with_torch_device(self):
method teardown_class (line 938) | def teardown_class(cls):
class TestVLLMClientServerVLM (line 947) | class TestVLLMClientServerVLM(TrlTestCase):
method setup_class (line 951) | def setup_class(cls):
method test_generate_with_token_ids_and_image (line 960) | def test_generate_with_token_ids_and_image(self):
method test_generate_with_token_ids_mixed_images (line 1000) | def test_generate_with_token_ids_mixed_images(self):
method teardown_class (line 1035) | def teardown_class(cls):
FILE: tests/testing_utils.py
function is_bitsandbytes_multi_backend_available (line 73) | def is_bitsandbytes_multi_backend_available() -> bool:
function is_ampere_or_newer (line 88) | def is_ampere_or_newer(device_index=0):
class TrlTestCase (line 100) | class TrlTestCase:
method set_tmp_dir (line 102) | def set_tmp_dir(self, tmp_path):
function ignore_warnings (line 106) | def ignore_warnings(message: str = None, category: type[Warning] = Warni...
function kill_process (line 129) | def kill_process(process):
FILE: trl/_compat.py
function _is_package_version_below (line 30) | def _is_package_version_below(package_name: str, version_threshold: str)...
function _is_package_version_at_least (line 54) | def _is_package_version_at_least(package_name: str, version_threshold: s...
function _patch_vllm_logging (line 78) | def _patch_vllm_logging() -> None:
function _patch_vllm_disabled_tqdm (line 86) | def _patch_vllm_disabled_tqdm() -> None:
function _patch_vllm_cached_tokenizer (line 110) | def _patch_vllm_cached_tokenizer() -> None:
function _patch_transformers_hybrid_cache (line 170) | def _patch_transformers_hybrid_cache() -> None:
function _patch_transformers_parallelism_config (line 214) | def _patch_transformers_parallelism_config() -> None:
FILE: trl/_lazy_module.py
class _LazyModule (line 22) | class _LazyModule(ModuleType):
method __init__ (line 29) | def __init__(self, name, module_file, import_structure, module_spec=No...
method __dir__ (line 46) | def __dir__(self):
method __getattr__ (line 55) | def __getattr__(self, name: str) -> Any:
method _get_module (line 69) | def _get_module(self, module_name: str):
method __reduce__ (line 78) | def __reduce__(self):
FILE: trl/chat_template_utils.py
function clone_chat_template (line 18) | def clone_chat_template(
function add_response_schema (line 429) | def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTok...
function is_chat_template_prefix_preserving (line 472) | def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -...
function get_training_chat_template (line 610) | def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | ...
function _validate_tool_calls (line 671) | def _validate_tool_calls(tool_calls: list | None) -> None:
function parse_response (line 709) | def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict:
FILE: trl/cli/accelerate_config.py
function resolve_accelerate_config_argument (line 19) | def resolve_accelerate_config_argument(launch_args: list[str]) -> list[s...
FILE: trl/cli/accelerate_launcher.py
function launch_training_script (line 22) | def launch_training_script(
FILE: trl/cli/commands/__init__.py
function get_commands (line 22) | def get_commands() -> list[Command]:
FILE: trl/cli/commands/base.py
class CommandContext (line 21) | class CommandContext:
method argv_after (line 26) | def argv_after(self, token: str) -> list[str]:
class Command (line 41) | class Command(ABC):
method __init__ (line 52) | def __init__(self, name: str, help_text: str):
method register (line 57) | def register(self, subparsers) -> None:
method run (line 61) | def run(self, args: Namespace, context: CommandContext) -> int:
FILE: trl/cli/commands/env.py
class EnvCommand (line 20) | class EnvCommand(Command):
method __init__ (line 23) | def __init__(self):
method register (line 26) | def register(self, subparsers) -> None:
method run (line 29) | def run(self, args: Namespace, context: CommandContext) -> int:
FILE: trl/cli/commands/skills.py
class SkillsCommand (line 21) | class SkillsCommand(Command):
method __init__ (line 24) | def __init__(self):
method register (line 28) | def register(self, subparsers) -> None:
method run (line 33) | def run(self, args: Namespace, context: CommandContext) -> int:
FILE: trl/cli/commands/training.py
function _subtract_subsequence (line 21) | def _subtract_subsequence(lst: list[str], subseq: list[str]) -> list[str]:
class TrainingCommand (line 34) | class TrainingCommand(Command):
method __init__ (line 45) | def __init__(self, name: str):
method register (line 48) | def register(self, subparsers) -> None:
method run (line 51) | def run(self, args: Namespace, context: CommandContext) -> int:
FILE: trl/cli/commands/vllm_serve.py
class VllmServeCommand (line 20) | class VllmServeCommand(Command):
method __init__ (line 23) | def __init__(self):
method register (line 26) | def register(self, subparsers) -> None:
method run (line 29) | def run(self, args: Namespace, context: CommandContext) -> int:
FILE: trl/cli/main.py
function _build_parser (line 22) | def _build_parser(commands: list[Command]) -> ArgumentParser:
function main (line 32) | def main(argv: list[str] | None = None) -> int:
FILE: trl/data_utils.py
function prepare_multimodal_messages (line 32) | def prepare_multimodal_messages(messages: list[dict[str, Any]], images: ...
function prepare_multimodal_messages_vllm (line 126) | def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> ...
function is_conversational (line 159) | def is_conversational(example: dict[str, Any]) -> bool:
function apply_chat_template (line 200) | def apply_chat_template(
function maybe_apply_chat_template (line 333) | def maybe_apply_chat_template(
function _unpair_row (line 397) | def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list...
function unpair_preference_dataset (line 408) | def unpair_preference_dataset(
function maybe_unpair_preference_dataset (line 451) | def maybe_unpair_preference_dataset(
function extract_prompt (line 502) | def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]:
function maybe_extract_prompt (line 589) | def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
function _get_dataset_format (line 615) | def _get_dataset_format(dataset: DatasetType) -> dict[str, Any]:
function _check_if_columns_can_be_packed (line 627) | def _check_if_columns_can_be_packed(columns: list[pa.Array]):
class _SegmentTree (line 639) | class _SegmentTree:
method __init__ (line 647) | def __init__(self, maxval: int):
method add (line 653) | def add(self, val):
method remove (line 663) | def remove(self, val):
method search (line 673) | def search(self, val):
function _pack_bfd (line 684) | def _pack_bfd(
function _pack_wrapped (line 774) | def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
function pack_dataset (line 790) | def pack_dataset(
function truncate_dataset (line 880) | def truncate_dataset(
function is_conversational_from_value (line 944) | def is_conversational_from_value(example: dict[str, Any]) -> bool:
function maybe_convert_to_chatml (line 984) | def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
FILE: trl/experimental/async_grpo/async_grpo_config.py
class AsyncGRPOConfig (line 21) | class AsyncGRPOConfig(_BaseConfig):
method __post_init__ (line 201) | def __post_init__(self):
FILE: trl/experimental/async_grpo/async_grpo_trainer.py
class _SupportsReset (line 47) | class _SupportsReset(Protocol):
method reset (line 48) | def reset(self, **kwargs) -> str | None: ...
class RolloutWorkerProtocol (line 54) | class RolloutWorkerProtocol(Protocol):
method start (line 57) | def start(self) -> None: ...
method stop (line 58) | def stop(self) -> None: ...
method pause (line 59) | def pause(self) -> None: ...
method resume (line 60) | def resume(self) -> None: ...
method send_weights (line 61) | def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -...
method update_model_version (line 62) | def update_model_version(self, version: int) -> None: ...
class StepIntervalCallback (line 65) | class StepIntervalCallback(TrainerCallback):
method __init__ (line 70) | def __init__(self, fn, every_n_steps: int):
method on_step_end (line 74) | def on_step_end(self, _args, state, _control, **_kwargs):
class RolloutQueueDataset (line 79) | class RolloutQueueDataset(torch.utils.data.IterableDataset):
method __init__ (line 80) | def __init__(self, rollout_queue, model_version_fn, max_staleness=3, t...
method __iter__ (line 86) | def __iter__(self):
class _EmptyIterableDataset (line 115) | class _EmptyIterableDataset(torch.utils.data.IterableDataset):
method __iter__ (line 118) | def __iter__(self):
class DataCollatorForRollout (line 123) | class DataCollatorForRollout(DataCollatorMixin):
method torch_call (line 127) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
class AsyncGRPOTrainer (line 168) | class AsyncGRPOTrainer(_BaseTrainer):
method __init__ (line 270) | def __init__(
method get_train_dataloader (line 391) | def get_train_dataloader(self) -> DataLoader:
method _set_signature_columns_if_needed (line 414) | def _set_signature_columns_if_needed(self):
method compute_loss (line 430) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method log (line 543) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _streaming_iter (line 552) | def _streaming_iter(self):
method _sync_weight (line 560) | def _sync_weight(self):
method _inner_training_loop (line 591) | def _inner_training_loop(self, *args, **kwargs):
FILE: trl/experimental/async_grpo/async_rollout_worker.py
class RolloutGroup (line 47) | class RolloutGroup:
class RolloutSample (line 64) | class RolloutSample:
class AsyncRolloutWorker (line 75) | class AsyncRolloutWorker:
method __init__ (line 83) | def __init__(
method _wait_for_server_ready_sync (line 178) | def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_i...
method _init_weight_transfer (line 201) | def _init_weight_transfer(self) -> None:
method update_model_version (line 231) | def update_model_version(self, model_version: int):
method _run_loops (line 234) | async def _run_loops(self, stop_event: asyncio.Event) -> None:
method start (line 245) | def start(self) -> None:
method stop (line 249) | def stop(self) -> None:
method _run (line 257) | def _run(self) -> None:
method pause (line 270) | def pause(self) -> None:
method resume (line 275) | def resume(self) -> None:
method send_weights (line 280) | def send_weights(self, iterator) -> None:
method _generate_loop (line 303) | async def _generate_loop(self, stop_event: asyncio.Event) -> None:
method _compute_rollout_metrics (line 422) | def _compute_rollout_metrics(self, samples: list[RolloutSample], scori...
method _score_loop (line 442) | async def _score_loop(self, stop_event: asyncio.Event) -> None:
method _repeat_iterator (line 497) | def _repeat_iterator(self) -> Iterator[tuple[int, dict[str, Any]]]:
method _generate_one (line 509) | async def _generate_one(
method _build_messages_suffix_ids (line 547) | def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -...
method _execute_tool_calls (line 572) | def _execute_tool_calls(
method _generate_one_turn (line 591) | async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[lis...
method _score_group (line 614) | async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
method _post (line 692) | async def _post(self, path: str, payload: dict, timeout: float, max_re...
FILE: trl/experimental/bco/bco_config.py
class BCOConfig (line 22) | class BCOConfig(_BaseConfig):
FILE: trl/experimental/bco/bco_trainer.py
function get_global_statistics (line 89) | def get_global_statistics(
class RunningMoments (line 110) | class RunningMoments:
method update (line 123) | def update(self, xs: torch.Tensor) -> tuple[float, float]:
method save_to_json (line 150) | def save_to_json(self, json_path: str):
method load_from_json (line 160) | def load_from_json(cls, accelerator: Accelerator, json_path: str):
function _tokenize (line 168) | def _tokenize(
function _process_tokens (line 239) | def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = ...
class BCOTrainer (line 349) | class BCOTrainer(_BaseTrainer):
method __init__ (line 410) | def __init__(
method match_underlying_distribution (line 803) | def match_underlying_distribution(self):
method _get_chosen_prob (line 806) | def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> to...
method _vectorize_prompt (line 835) | def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mas...
method _get_prompt_embeddings (line 853) | def _get_prompt_embeddings(
method _get_sample_prompt_embeddings (line 875) | def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size:...
method _save_optimizer_and_scheduler (line 907) | def _save_optimizer_and_scheduler(self, output_dir):
method _load_optimizer_and_scheduler (line 918) | def _load_optimizer_and_scheduler(self, checkpoint):
method null_ref_context (line 936) | def null_ref_context(self):
method get_train_dataloader (line 949) | def get_train_dataloader(self) -> DataLoader:
method get_eval_dataloader (line 983) | def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> ...
method compute_reference_log_probs (line 1029) | def compute_reference_log_probs(self, padded_batch: dict) -> dict:
method get_batch_logps (line 1072) | def get_batch_logps(
method forward (line 1119) | def forward(
method _get_udm_weight (line 1167) | def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> t...
method bco_loss (line 1176) | def bco_loss(
method get_batch_loss_metrics (line 1232) | def get_batch_loss_metrics(
method compute_loss (line 1326) | def compute_loss(
method store_metrics (line 1350) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal...
method _get_train_sampler (line 1354) | def _get_train_sampler(self, dataset: Dataset | None = None) -> torch....
method generate_from_model_and_ref (line 1361) | def generate_from_model_and_ref(self, model, batch: dict[str, torch.Lo...
method prediction_step (line 1408) | def prediction_step(
method evaluation_loop (line 1446) | def evaluation_loop(
method log (line 1506) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 1542) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/bema_for_ref_model/callback.py
class CallbackHandlerWithRefModel (line 28) | class CallbackHandlerWithRefModel(CallbackHandler):
method __init__ (line 33) | def __init__(self, callbacks, model, ref_model, processing_class, opti...
method call_event (line 38) | def call_event(self, event, args, state, control, **kwargs):
class BEMACallback (line 59) | class BEMACallback(_BEMACallback):
method __init__ (line 128) | def __init__(
method on_step_end (line 158) | def on_step_end(
method _update_model_with_bema_weights (line 202) | def _update_model_with_bema_weights(self, model, bema_state_dict, is_p...
FILE: trl/experimental/bema_for_ref_model/dpo_trainer.py
class DPOTrainer (line 19) | class DPOTrainer(_DPOTrainer):
method __init__ (line 20) | def __init__(self, *args, **kwargs):
FILE: trl/experimental/cpo/cpo_config.py
class CPOConfig (line 22) | class CPOConfig(_BaseConfig):
method __post_init__ (line 176) | def __post_init__(self):
FILE: trl/experimental/cpo/cpo_trainer.py
class CPOTrainer (line 74) | class CPOTrainer(_BaseTrainer):
method __init__ (line 129) | def __init__(
method build_tokenized_answer (line 389) | def build_tokenized_answer(self, prompt, answer):
method tokenize_row (line 438) | def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | N...
method concatenated_inputs (line 571) | def concatenated_inputs(
method cpo_loss (line 635) | def cpo_loss(
method get_batch_logps (line 709) | def get_batch_logps(
method concatenated_forward (line 749) | def concatenated_forward(
method get_batch_loss_metrics (line 822) | def get_batch_loss_metrics(
method compute_loss (line 876) | def compute_loss(
method generate_from_model (line 897) | def generate_from_model(self, model, batch: dict[str, torch.LongTensor...
method prediction_step (line 920) | def prediction_step(
method store_metrics (line 957) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal...
method evaluation_loop (line 961) | def evaluation_loop(
method log (line 1012) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _shift_right (line 1030) | def _shift_right(self, input_ids):
method _save_checkpoint (line 1054) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/dppo/dppo_config.py
class DPPOConfig (line 22) | class DPPOConfig(GRPOConfig):
method __post_init__ (line 93) | def __post_init__(self):
FILE: trl/experimental/dppo/dppo_trainer.py
function _strip_padding (line 66) | def _strip_padding(tensor: torch.Tensor, mask: torch.Tensor) -> list[list]:
class DPPOTrainer (line 71) | class DPPOTrainer(GRPOTrainer):
method __init__ (line 190) | def __init__(
method _tokenize_prompts (line 234) | def _tokenize_prompts(self, prompts: list):
method _generate_single_turn (line 274) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
method _tool_call_loop (line 418) | def _tool_call_loop(
method _generate (line 619) | def _generate(self, prompts: list):
method _get_per_token_logps_with_topk (line 752) | def _get_per_token_logps_with_topk(
method _generate_and_score_completions (line 841) | def _generate_and_score_completions(
method _compute_divergence_mask (line 1187) | def _compute_divergence_mask(
method _compute_loss (line 1270) | def _compute_loss(self, model, inputs):
FILE: trl/experimental/gfpo/gfpo_config.py
class GFPOConfig (line 21) | class GFPOConfig(_GRPOConfig):
method __post_init__ (line 29) | def __post_init__(self):
FILE: trl/experimental/gfpo/gfpo_trainer.py
class GFPOTrainer (line 33) | class GFPOTrainer(_GRPOTrainer):
method __init__ (line 34) | def __init__(
method _generate_and_score_completions (line 69) | def _generate_and_score_completions(self, inputs):
FILE: trl/experimental/gkd/gkd_config.py
class GKDConfig (line 22) | class GKDConfig(SFTConfig):
method __post_init__ (line 104) | def __post_init__(self):
FILE: trl/experimental/gkd/gkd_trainer.py
class GKDTrainer (line 53) | class GKDTrainer(SFTTrainer):
method __init__ (line 110) | def __init__(
method generalized_jsd_loss (line 222) | def generalized_jsd_loss(
method compute_loss (line 292) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method generate_on_policy_outputs (line 394) | def generate_on_policy_outputs(model, inputs, generation_config, pad_t...
method training_step (line 416) | def training_step(
FILE: trl/experimental/gold/gold_config.py
class GOLDConfig (line 23) | class GOLDConfig(SFTConfig):
method __post_init__ (line 385) | def __post_init__(self):
FILE: trl/experimental/gold/gold_trainer.py
function print_prompt_completions_sample_uld (line 94) | def print_prompt_completions_sample_uld(
function build_teacher_inputs_from_texts (line 173) | def build_teacher_inputs_from_texts(
class ULDLoss (line 236) | class ULDLoss(nn.Module):
method __init__ (line 241) | def __init__(self, config: GOLDConfig, student_tokenizer=None, teacher...
method __call__ (line 270) | def __call__(
method _initialize_vocabulary_mapping (line 304) | def _initialize_vocabulary_mapping(self):
method _compute_distillation_loss (line 336) | def _compute_distillation_loss(
method _build_alignment_groups_from_ids (line 438) | def _build_alignment_groups_from_ids(self, student_token_ids, teacher_...
method _merge_probabilities_with_alignment_groups (line 534) | def _merge_probabilities_with_alignment_groups(self, probs, alignment_...
method _compute_hybrid_uld_loss (line 611) | def _compute_hybrid_uld_loss(self, student_aligned, teacher_aligned):
method _compute_jsd_loss_for_matched_tokens (line 708) | def _compute_jsd_loss_for_matched_tokens(self, student_logits, teacher...
method _get_start_and_size_answers (line 737) | def _get_start_and_size_answers(self, answer_tensors):
class GOLDVLLMSyncCallback (line 754) | class GOLDVLLMSyncCallback(TrainerCallback):
method __init__ (line 757) | def __init__(self, trainer):
method on_step_end (line 760) | def on_step_end(self, args, state: TrainerState, control: TrainerContr...
class GOLDTrainer (line 774) | class GOLDTrainer(SFTTrainer):
method __init__ (line 789) | def __init__(
method _set_signature_columns_if_needed (line 1047) | def _set_signature_columns_if_needed(self):
method _get_train_sampler (line 1065) | def _get_train_sampler(self, dataset=None):
method get_train_dataloader (line 1077) | def get_train_dataloader(self):
method _prepare_inputs (line 1118) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A...
method _decode_completion_texts_from_labels (line 1131) | def _decode_completion_texts_from_labels(self, slice_inputs: dict[str,...
method _ensure_original_text_fields (line 1151) | def _ensure_original_text_fields(
method _build_sequence_batch (line 1177) | def _build_sequence_batch(
method _fill_buffer (line 1197) | def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any]...
method _generate_on_policy_for_slices (line 1231) | def _generate_on_policy_for_slices(
method _deduplicate_prompts (line 1303) | def _deduplicate_prompts(
method _generate_vllm_server_global (line 1323) | def _generate_vllm_server_global(
method _generate_vllm_colocate (line 1378) | def _generate_vllm_colocate(
method _generate_non_vllm_for_slices (line 1445) | def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.T...
method _process_completions_to_buffer (line 1472) | def _process_completions_to_buffer(
method _prepare_dataset (line 1564) | def _prepare_dataset(
method _prepare_dataset_with_original_text (line 1584) | def _prepare_dataset_with_original_text(
method generalized_jsd_loss (line 1817) | def generalized_jsd_loss(
method compute_loss (line 1896) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method generate_on_policy_outputs (line 2082) | def generate_on_policy_outputs(self, model, inputs, generation_config,...
method _sync_fsdp_params_to_vllm (line 2163) | def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "...
method _move_model_to_vllm (line 2190) | def _move_model_to_vllm(self):
method _wake_vllm_if_needed (line 2260) | def _wake_vllm_if_needed(self):
method _get_liger_zero3_lm_head_gather_ctx (line 2265) | def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module):
method training_step (line 2287) | def training_step(
method log (line 2325) | def log(self, logs: dict[str, float], start_time: float | None = None)...
FILE: trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py
class GRPOWithReplayBufferConfig (line 21) | class GRPOWithReplayBufferConfig(GRPOConfig):
FILE: trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
class ReplayBuffer (line 28) | class ReplayBuffer:
method __init__ (line 33) | def __init__(self, max_size: int):
method add (line 37) | def add(self, scores: list[float], data: list[dict]):
method sample (line 46) | def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]]:
class GRPOWithReplayBufferTrainer (line 60) | class GRPOWithReplayBufferTrainer(GRPOTrainer):
method __init__ (line 61) | def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **k...
method _generate_and_score_completions (line 65) | def _generate_and_score_completions(
method slice_group_data (line 427) | def slice_group_data(
method update_replay_buffer (line 449) | def update_replay_buffer(
method sample_from_replay_buffer (line 528) | def sample_from_replay_buffer(
method update_with_replay_buffer (line 574) | def update_with_replay_buffer(
FILE: trl/experimental/gspo_token/grpo_trainer.py
class GRPOTrainer (line 21) | class GRPOTrainer(_GRPOTrainer):
method _compute_loss (line 22) | def _compute_loss(self, model, inputs):
FILE: trl/experimental/judges/judges.py
function _ensure_llm_blender_importable (line 57) | def _ensure_llm_blender_importable() -> None:
class BaseJudge (line 77) | class BaseJudge(ABC):
method judge (line 83) | def judge(self, prompts: list[str], completions: list[str], shuffle_or...
class BaseRankJudge (line 87) | class BaseRankJudge(ABC):
method judge (line 107) | def judge(self, prompts: list[str], completions: list[list[str]], shuf...
class BasePairwiseJudge (line 128) | class BasePairwiseJudge(BaseJudge):
method judge (line 134) | def judge(self, prompts: list[str], completions: list[list[str]], shuf...
class BaseBinaryJudge (line 160) | class BaseBinaryJudge(BaseJudge):
method judge (line 166) | def judge(
class PairRMJudge (line 200) | class PairRMJudge(BasePairwiseJudge):
method __init__ (line 227) | def __init__(self):
method judge (line 242) | def judge(
class HfPairwiseJudge (line 309) | class HfPairwiseJudge(BasePairwiseJudge):
method __init__ (line 327) | def __init__(
method judge (line 336) | def judge(self, prompts: list[str], completions: list[list[str]], shuf...
class OpenAIPairwiseJudge (line 365) | class OpenAIPairwiseJudge(BasePairwiseJudge):
method __init__ (line 383) | def __init__(
method judge (line 397) | def judge(self, prompts: list[str], completions: list[list[str]], shuf...
class AllTrueJudge (line 440) | class AllTrueJudge(BaseBinaryJudge):
method __init__ (line 454) | def __init__(self, judges: list[BaseBinaryJudge]):
method judge (line 457) | def judge(
FILE: trl/experimental/kto/kto_config.py
class KTOConfig (line 22) | class KTOConfig(_BaseConfig):
method __post_init__ (line 145) | def __post_init__(self):
FILE: trl/experimental/kto/kto_trainer.py
function _get_kl_dataset (line 85) | def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]:
function _tokenize (line 96) | def _tokenize(
function _process_tokens (line 156) | def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = ...
class KTOTrainer (line 248) | class KTOTrainer(_BaseTrainer):
method __init__ (line 309) | def __init__(
method null_ref_context (line 743) | def null_ref_context(self):
method get_train_dataloader (line 756) | def get_train_dataloader(self) -> DataLoader:
method get_eval_dataloader (line 800) | def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> ...
method compute_reference_log_probs (line 855) | def compute_reference_log_probs(self, padded_batch: dict) -> dict:
method get_batch_logps (line 899) | def get_batch_logps(
method forward (line 939) | def forward(
method kto_loss (line 985) | def kto_loss(
method _compute_kl_logps (line 1062) | def _compute_kl_logps(self, model, batch):
method _compute_loss_liger (line 1081) | def _compute_loss_liger(self, model, batch):
method get_batch_loss_metrics (line 1174) | def get_batch_loss_metrics(
method compute_loss (line 1290) | def compute_loss(
method store_metrics (line 1314) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal...
method _get_train_sampler (line 1318) | def _get_train_sampler(self, dataset: Dataset | None = None) -> torch....
method generate_from_model_and_ref (line 1325) | def generate_from_model_and_ref(self, model, batch: dict[str, torch.Lo...
method prediction_step (line 1373) | def prediction_step(
method evaluation_loop (line 1411) | def evaluation_loop(
method log (line 1471) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 1507) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/merge_model_callback.py
function upload_model_to_hf (line 35) | def upload_model_to_hf(folder_path: str, repo_id: str):
class MergeConfig (line 48) | class MergeConfig:
method __init__ (line 82) | def __init__(self, method: str = "linear"):
method create_merge_config_linear (line 114) | def create_merge_config_linear(self) -> "MergeConfiguration":
method create_merge_config_ties (line 133) | def create_merge_config_ties(self) -> "MergeConfiguration":
method create_merge_config_dare_ties (line 177) | def create_merge_config_dare_ties(self) -> "MergeConfiguration":
method create_merge_config_slerp (line 221) | def create_merge_config_slerp(self) -> "MergeConfiguration":
method create (line 260) | def create(self) -> "MergeConfiguration":
function merge_models (line 271) | def merge_models(config: "MergeConfiguration", out_path: str):
class MergeModelCallback (line 294) | class MergeModelCallback(TrainerCallback):
method __init__ (line 319) | def __init__(
method _merge_and_maybe_push (line 333) | def _merge_and_maybe_push(self, output_dir, global_step, model):
method on_save (line 346) | def on_save(self, args, state, control, model=None, **kwargs):
method on_train_end (line 350) | def on_train_end(self, args, state, control, model=None, **kwargs):
FILE: trl/experimental/minillm/minillm_config.py
class MiniLLMConfig (line 24) | class MiniLLMConfig(GRPOConfig):
method __post_init__ (line 87) | def __post_init__(self):
FILE: trl/experimental/minillm/minillm_trainer.py
function dummy_reward_func (line 43) | def dummy_reward_func(completions: list, **kwargs):
class MiniLLMTrainer (line 48) | class MiniLLMTrainer(GRPOTrainer):
method __init__ (line 166) | def __init__(
method _single_step_decomposition_loss (line 245) | def _single_step_decomposition_loss(
method _compute_advantage (line 292) | def _compute_advantage(
method compute_loss (line 349) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
FILE: trl/experimental/nash_md/nash_md_config.py
class NashMDConfig (line 21) | class NashMDConfig(OnlineDPOConfig):
method __post_init__ (line 43) | def __post_init__(self):
FILE: trl/experimental/nash_md/nash_md_trainer.py
class GeometricMixtureWrapper (line 50) | class GeometricMixtureWrapper(GenerationMixin):
method __init__ (line 66) | def __init__(self, model, ref_model, generation_config, mixture_coef=0...
method __call__ (line 78) | def __call__(self, *args, **kwargs):
method forward (line 82) | def forward(self, *args, **kwargs):
method prepare_inputs_for_generation (line 93) | def prepare_inputs_for_generation(self, *args, **kwargs):
method _validate_model_class (line 101) | def _validate_model_class(self):
method _validate_model_kwargs (line 104) | def _validate_model_kwargs(self, model_kwargs):
class NashMDTrainer (line 108) | class NashMDTrainer(OnlineDPOTrainer):
method __init__ (line 170) | def __init__(
method mixture_coef (line 236) | def mixture_coef(self):
method _generate_completions (line 243) | def _generate_completions(self, model, prompts):
method _process_completions (line 298) | def _process_completions(self, model_output, mixture_output, prompts):
method _compute_rewards (line 325) | def _compute_rewards(self, model_data, mixture_data, context_length):
method _compute_judge (line 343) | def _compute_judge(self, model_data, mixture_data, context_length):
method _compute_logprobs (line 377) | def _compute_logprobs(self, model, model_data, context_length):
method _compute_losses (line 402) | def _compute_losses(
method _log_statistics (line 422) | def _log_statistics(
method training_step (line 481) | def training_step(
FILE: trl/experimental/online_dpo/online_dpo_config.py
class OnlineDPOConfig (line 23) | class OnlineDPOConfig(_BaseConfig):
method __post_init__ (line 386) | def __post_init__(self):
FILE: trl/experimental/online_dpo/online_dpo_trainer.py
class OnlineDPOTrainer (line 104) | class OnlineDPOTrainer(_BaseTrainer):
method __init__ (line 182) | def __init__(
method beta (line 587) | def beta(self):
method tokenize_row (line 595) | def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrai...
method _enable_gradient_checkpointing (line 610) | def _enable_gradient_checkpointing(self, model: PreTrainedModel, args:...
method _generate_vllm (line 625) | def _generate_vllm(self, prompts, images=None):
method _generate_vllm_server (line 655) | def _generate_vllm_server(self, prompts, images=None):
method _generate_vllm_colocate (line 731) | def _generate_vllm_colocate(self, prompts, images=None):
method _sync_fsdp2_params_to_vllm (line 772) | def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
method _move_model_to_vllm (line 795) | def _move_model_to_vllm(self):
method _sync_fsdp1_params_to_vllm (line 871) | def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = ...
method _fix_param_name_to_vllm (line 898) | def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | No...
method process_vision_row (line 906) | def process_vision_row(
method _generate (line 933) | def _generate(self, model, prompts, images=None):
method _calculate_rewards_from_functions (line 1088) | def _calculate_rewards_from_functions(self, prompts, completions, comp...
method _forward (line 1134) | def _forward(self, model, prompt_ids, prompt_mask, completion_ids, com...
method training_step (line 1172) | def training_step(
method _maybe_log_save_evaluate (line 1394) | def _maybe_log_save_evaluate(
method _save_checkpoint (line 1440) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/openenv/utils.py
function _build_base_generation_kwargs (line 29) | def _build_base_generation_kwargs(
function _build_colocate_sampling_params (line 60) | def _build_colocate_sampling_params(
function _build_server_generation_kwargs (line 80) | def _build_server_generation_kwargs(
function generate_rollout_completions (line 88) | def generate_rollout_completions(
function _generate_rollout_completions_server (line 116) | def _generate_rollout_completions_server(
function _generate_rollout_completions_colocate (line 156) | def _generate_rollout_completions_colocate(
FILE: trl/experimental/orpo/orpo_config.py
class ORPOConfig (line 22) | class ORPOConfig(_BaseConfig):
FILE: trl/experimental/orpo/orpo_trainer.py
function log1mexp (line 78) | def log1mexp(x: torch.FloatTensor) -> torch.FloatTensor:
class ORPOTrainer (line 85) | class ORPOTrainer(_BaseTrainer):
method __init__ (line 138) | def __init__(
method build_tokenized_answer (line 374) | def build_tokenized_answer(self, prompt, answer):
method tokenize_row (line 423) | def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | N...
method concatenated_inputs (line 566) | def concatenated_inputs(
method odds_ratio_loss (line 630) | def odds_ratio_loss(
method get_batch_logps (line 665) | def get_batch_logps(
method concatenated_forward (line 705) | def concatenated_forward(
method get_batch_loss_metrics (line 784) | def get_batch_loss_metrics(
method compute_loss (line 839) | def compute_loss(
method generate_from_model (line 863) | def generate_from_model(self, model, batch: dict[str, torch.LongTensor...
method prediction_step (line 886) | def prediction_step(
method store_metrics (line 928) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal...
method evaluation_loop (line 932) | def evaluation_loop(
method log (line 983) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _shift_right (line 1001) | def _shift_right(self, input_ids):
method _save_checkpoint (line 1025) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/papo/papo_config.py
class PAPOConfig (line 22) | class PAPOConfig(GRPOConfig):
method __post_init__ (line 60) | def __post_init__(self):
FILE: trl/experimental/papo/papo_trainer.py
class PAPOTrainer (line 27) | class PAPOTrainer(GRPOTrainer):
method __init__ (line 114) | def __init__(
method _mask_image (line 154) | def _mask_image(self, pixel_values: torch.Tensor, mask_ratio: float = ...
method _compute_loss (line 214) | def _compute_loss(self, model, inputs):
FILE: trl/experimental/ppo/modeling_value_head.py
class PreTrainedModelWrapper (line 52) | class PreTrainedModelWrapper(nn.Module):
method __init__ (line 79) | def __init__(
method from_pretrained (line 107) | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, *...
method _get_checkpoint_from_hub (line 339) | def _get_checkpoint_from_hub(
method _get_current_device (line 391) | def _get_current_device(cls):
method _split_kwargs (line 409) | def _split_kwargs(cls, kwargs):
method add_and_load_reward_modeling_adapter (line 439) | def add_and_load_reward_modeling_adapter(
method push_to_hub (line 509) | def push_to_hub(self, *args, **kwargs):
method save_pretrained (line 523) | def save_pretrained(self, *args, **kwargs):
method state_dict (line 550) | def state_dict(self, *args, **kwargs):
method post_init (line 556) | def post_init(self, *args, **kwargs):
method compute_reward_score (line 563) | def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
class ValueHead (line 594) | class ValueHead(nn.Module):
method __init__ (line 599) | def __init__(self, config, **kwargs):
method forward (line 622) | def forward(self, hidden_states):
class AutoModelForCausalLMWithValueHead (line 634) | class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
method __init__ (line 665) | def __init__(self, pretrained_model, **kwargs):
method _init_weights (line 681) | def _init_weights(self, **kwargs):
method forward (line 703) | def forward(
method generate (line 758) | def generate(self, *args, **kwargs):
method state_dict (line 772) | def state_dict(self, *args, **kwargs):
method push_to_hub (line 788) | def push_to_hub(self, *args, **kwargs):
method post_init (line 793) | def post_init(self, state_dict):
class AutoModelForSeq2SeqLMWithValueHead (line 838) | class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
method __init__ (line 861) | def __init__(self, pretrained_model, **kwargs):
method _has_lm_head (line 873) | def _has_lm_head(self):
method post_init (line 880) | def post_init(self, state_dict):
method state_dict (line 934) | def state_dict(self, *args, **kwargs):
method push_to_hub (line 950) | def push_to_hub(self, *args, **kwargs):
method _init_weights (line 955) | def _init_weights(self, **kwargs):
method forward (line 969) | def forward(
method generate (line 1003) | def generate(self, *args, **kwargs):
FILE: trl/experimental/ppo/ppo_config.py
class PPOConfig (line 22) | class PPOConfig(_BaseConfig):
FILE: trl/experimental/ppo/ppo_trainer.py
function generate (line 84) | def generate(
function batch_generation (line 124) | def batch_generation(
function exact_div (line 156) | def exact_div(a, b, custom_error_message=""):
function print_rich_table (line 163) | def print_rich_table(df: pd.DataFrame) -> None:
function truncate_response (line 177) | def truncate_response(stop_token_id: int, pad_token_id: int, responses: ...
function forward (line 200) | def forward(
class OnlineTrainerState (line 233) | class OnlineTrainerState(TrainerState):
function masked_mean (line 246) | def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | N...
function masked_var (line 254) | def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool ...
function masked_whiten (line 273) | def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: ...
class PolicyAndValueWrapper (line 284) | class PolicyAndValueWrapper(nn.Module):
method __init__ (line 285) | def __init__(self, policy, value_model) -> None:
method gradient_checkpointing_enable (line 292) | def gradient_checkpointing_enable(self, **kwargs):
method gradient_checkpointing_disable (line 296) | def gradient_checkpointing_disable(self):
method forward (line 300) | def forward(self, **kwargs):
class PPOTrainer (line 306) | class PPOTrainer(_BaseTrainer):
method __init__ (line 358) | def __init__(
method get_train_dataloader (line 577) | def get_train_dataloader(self) -> DataLoader:
method get_eval_dataloader (line 580) | def get_eval_dataloader(self) -> DataLoader:
method null_ref_context (line 584) | def null_ref_context(self):
method save_model (line 597) | def save_model(self, output_dir: str | None = None, _internal_call: bo...
method train (line 611) | def train(self):
method generate_completions (line 957) | def generate_completions(self, sampling: bool = False):
method _save_checkpoint (line 1030) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/prm/prm_config.py
class PRMConfig (line 21) | class PRMConfig(_BaseConfig):
FILE: trl/experimental/prm/prm_trainer.py
function compute_accuracy (line 52) | def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]:
class PRMTrainer (line 97) | class PRMTrainer(_BaseTrainer):
method __init__ (line 150) | def __init__(
method tokenize_row (line 259) | def tokenize_row(
method _save_checkpoint (line 351) | def _save_checkpoint(self, model, trial):
FILE: trl/experimental/utils.py
class DPODataCollatorWithPadding (line 46) | class DPODataCollatorWithPadding:
method __call__ (line 60) | def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
class DataCollatorForChatML (line 127) | class DataCollatorForChatML:
method __post_init__ (line 138) | def __post_init__(self):
method __call__ (line 145) | def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch....
function truncate_right (line 260) | def truncate_right(
function add_bos_token_if_needed (line 292) | def add_bos_token_if_needed(
function add_eos_token_if_needed (line 314) | def add_eos_token_if_needed(
function first_true_indices (line 326) | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.T...
function get_reward (line 349) | def get_reward(
function prepare_model_for_kbit_training (line 399) | def prepare_model_for_kbit_training(model, use_gradient_checkpointing=Tr...
function enable_gradient_checkpointing (line 436) | def enable_gradient_checkpointing(
function prepare_peft_model (line 465) | def prepare_peft_model(
function pad_to_length (line 520) | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int | fl...
function empty_cache (line 535) | def empty_cache() -> None:
function peft_module_casting_to_bf16 (line 553) | def peft_module_casting_to_bf16(model):
function create_reference_model (line 571) | def create_reference_model(
FILE: trl/experimental/winrate_callback.py
function _generate_completions (line 42) | def _generate_completions(
function _win_rate_completions_df (line 82) | def _win_rate_completions_df(
class WinRateCallback (line 92) | class WinRateCallback(TrainerCallback):
method __init__ (line 134) | def __init__(
method on_train_begin (line 158) | def on_train_begin(self, args: TrainingArguments, state: TrainerState,...
method on_evaluate (line 226) | def on_evaluate(self, args: TrainingArguments, state: TrainerState, co...
FILE: trl/experimental/xpo/xpo_config.py
class XPOConfig (line 21) | class XPOConfig(OnlineDPOConfig):
method __post_init__ (line 41) | def __post_init__(self):
FILE: trl/experimental/xpo/xpo_trainer.py
class XPOTrainer (line 49) | class XPOTrainer(OnlineDPOTrainer):
method __init__ (line 109) | def __init__(
method alpha (line 180) | def alpha(self):
method _generate_completions (line 187) | def _generate_completions(self, prompts, model):
method _process_completions (line 227) | def _process_completions(self, model_output, ref_output, prompts):
method _compute_rewards (line 254) | def _compute_rewards(self, model_data, ref_data, context_length):
method _compute_judge (line 272) | def _compute_judge(self, model_data, ref_data, context_length):
method _compute_logprobs (line 307) | def _compute_logprobs(self, model, model_data, ref_data, context_length):
method _compute_losses (line 339) | def _compute_losses(
method _log_statistics (line 379) | def _log_statistics(
method training_step (line 462) | def training_step(
FILE: trl/extras/profiling.py
class ProfilingContext (line 30) | class ProfilingContext:
method __init__ (line 75) | def __init__(
method __enter__ (line 90) | def __enter__(self):
method __exit__ (line 95) | def __exit__(self, exc_type, exc_val, exc_tb):
method _log_metrics (line 102) | def _log_metrics(self, duration: float) -> None:
function profiling_context (line 125) | def profiling_context(trainer: Trainer, name: str) -> ProfilingContext:
function profiling_decorator (line 167) | def profiling_decorator(func: Callable) -> Callable:
FILE: trl/generation/vllm_client.py
function pil_to_base64 (line 51) | def pil_to_base64(image):
class VLLMClient (line 58) | class VLLMClient:
method __init__ (line 122) | def __init__(
method check_server (line 168) | def check_server(self, total_timeout: float = 0.0, retry_interval: flo...
method generate (line 204) | def generate(
method chat (line 302) | def chat(
method init_communicator (line 416) | def init_communicator(self, device: torch.device | str | int = 0):
method update_named_param (line 489) | def update_named_param(self, name: str, weights: torch.Tensor):
method update_model_params (line 514) | def update_model_params(self, model: nn.Module):
method reset_prefix_cache (line 526) | def reset_prefix_cache(self):
method close_communicator (line 535) | def close_communicator(self):
FILE: trl/generation/vllm_generation.py
function empty_cache (line 44) | def empty_cache() -> None:
function extract_logprobs (line 62) | def extract_logprobs(all_outputs: list["RequestOutput"]):
class VLLMGeneration (line 103) | class VLLMGeneration:
method __init__ (line 216) | def __init__(
method _init_vllm (line 284) | def _init_vllm(self):
method _fix_param_name_to_vllm (line 369) | def _fix_param_name_to_vllm(self, name: str, extra_prefixes: list[str]...
method _sync_fsdp1_params_to_vllm (line 377) | def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = ...
method _sync_fsdp2_params_to_vllm (line 406) | def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
method sync_weights (line 432) | def sync_weights(self):
method generate (line 521) | def generate(
FILE: trl/import_utils.py
function _is_package_available (line 29) | def _is_package_available(pkg_name: str, return_version: bool = False) -...
function is_deepspeed_available (line 60) | def is_deepspeed_available() -> bool:
function is_fastapi_available (line 64) | def is_fastapi_available() -> bool:
function is_jmespath_available (line 68) | def is_jmespath_available() -> bool:
function is_joblib_available (line 72) | def is_joblib_available() -> bool:
function is_liger_kernel_available (line 76) | def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSIO...
function is_llm_blender_available (line 81) | def is_llm_blender_available() -> bool:
function is_math_verify_available (line 85) | def is_math_verify_available() -> bool:
function is_mergekit_available (line 89) | def is_mergekit_available() -> bool:
function is_pydantic_available (line 93) | def is_pydantic_available() -> bool:
function is_requests_available (line 97) | def is_requests_available() -> bool:
function is_unsloth_available (line 101) | def is_unsloth_available() -> bool:
function is_uvicorn_available (line 105) | def is_uvicorn_available() -> bool:
function is_vllm_available (line 109) | def is_vllm_available(min_version: str | None = None) -> bool:
function is_vllm_ascend_available (line 123) | def is_vllm_ascend_available() -> bool:
function is_weave_available (line 127) | def is_weave_available() -> bool:
class TRLExperimentalWarning (line 131) | class TRLExperimentalWarning(UserWarning):
function suppress_warning (line 138) | def suppress_warning(category):
function suppress_experimental_warning (line 144) | def suppress_experimental_warning():
FILE: trl/models/activation_offloading.py
function _get_unique_tensor_key (line 49) | def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple:
class OffloadActivations (line 79) | class OffloadActivations(saved_tensors_hooks):
method __init__ (line 120) | def __init__(
method update_model_params (line 529) | def update_model_params(self, model: nn.Module):
class NoOpManager (line 562) | class NoOpManager(saved_tensors_hooks):
method __init__ (line 571) | def __init__(self) -> None:
function get_act_offloading_ctx_manager (line 578) | def get_act_offloading_ctx_manager(
FILE: trl/models/utils.py
function remove_hooks (line 47) | def remove_hooks(model: "DeepSpeedEngine") -> None:
function get_all_parameters (line 70) | def get_all_parameters(sub_module, recurse=False):
function iter_params (line 74) | def iter_params(module, recurse=False):
function add_hooks (line 78) | def add_hooks(model: "DeepSpeedEngine") -> None:
function _unwrap_model_for_generation (line 98) | def _unwrap_model_for_generation(
function _override_model_generation_config (line 145) | def _override_model_generation_config(model, generation_kwargs=None):
function unwrap_model_for_generation (line 187) | def unwrap_model_for_generation(
function prepare_deepspeed (line 225) | def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
function prepare_fsdp (line 265) | def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule:
class _ForwardRedirection (line 313) | class _ForwardRedirection:
method __call__ (line 323) | def __call__(
method on_after_inner_forward (line 357) | def on_after_inner_forward(self, wrapper_module: nn.Module, original_m...
method on_after_outer_forward (line 360) | def on_after_outer_forward(self, wrapper_module: nn.Module, original_m...
function disable_gradient_checkpointing (line 365) | def disable_gradient_checkpointing(model: PreTrainedModel, gradient_chec...
function create_reference_model (line 385) | def create_reference_model(
FILE: trl/rewards/accuracy_rewards.py
function accuracy_reward (line 26) | def accuracy_reward(completions: list[list[dict[str, str]]], solution: l...
function reasoning_accuracy_reward (line 97) | def reasoning_accuracy_reward(
FILE: trl/rewards/format_rewards.py
function think_format_reward (line 18) | def think_format_reward(completions: list[list[dict[str, str]]], **kwarg...
FILE: trl/rewards/other_rewards.py
function get_soft_overlong_punishment (line 18) | def get_soft_overlong_punishment(max_completion_len: int, soft_punish_ca...
FILE: trl/scripts/_hf_argparser.py
function string_to_bool (line 40) | def string_to_bool(v):
function make_choice_type_function (line 53) | def make_choice_type_function(choices: list) -> Callable[[str], Any]:
function HfArg (line 68) | def HfArg(
class HfArgumentParser (line 115) | class HfArgumentParser(ArgumentParser):
method __init__ (line 132) | def __init__(self, dataclass_types: DataClassType | Iterable[DataClass...
method _parse_dataclass_field (line 150) | def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses....
method _add_dataclass_arguments (line 255) | def _add_dataclass_arguments(self, dtype: DataClassType):
method parse_args_into_dataclasses (line 276) | def parse_args_into_dataclasses(
method parse_dict (line 362) | def parse_dict(self, args: dict[str, Any], allow_extra_keys: bool = Fa...
method parse_json_file (line 390) | def parse_json_file(self, json_file: str | os.PathLike, allow_extra_ke...
method parse_yaml_file (line 412) | def parse_yaml_file(self, yaml_file: str | os.PathLike, allow_extra_ke...
FILE: trl/scripts/dpo.py
function main (line 69) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 149) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/env.py
function print_env (line 26) | def print_env():
FILE: trl/scripts/grpo.py
class GRPOScriptArguments (line 38) | class GRPOScriptArguments(ScriptArguments):
function main (line 72) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 171) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/kto.py
function main (line 75) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 141) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/reward.py
function main (line 32) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 80) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/rloo.py
class RLOOScriptArguments (line 38) | class RLOOScriptArguments(ScriptArguments):
function main (line 72) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 155) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/sft.py
function main (line 71) | def main(script_args, training_args, model_args, dataset_args):
function make_parser (line 147) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/scripts/utils.py
class DatasetConfig (line 40) | class DatasetConfig:
class DatasetMixtureConfig (line 74) | class DatasetMixtureConfig:
method __post_init__ (line 129) | def __post_init__(self):
class ScriptArguments (line 138) | class ScriptArguments:
function init_zero_verbose (line 197) | def init_zero_verbose():
class TrlParser (line 226) | class TrlParser(HfArgumentParser):
method __init__ (line 274) | def __init__(
method parse_args_and_config (line 295) | def parse_args_and_config(
method set_defaults_with_config (line 351) | def set_defaults_with_config(self, **kwargs) -> list[str]:
function get_git_commit_hash (line 381) | def get_git_commit_hash(package_name):
function get_dataset (line 404) | def get_dataset(mixture_config: DatasetMixtureConfig) -> "DatasetDict":
FILE: trl/scripts/vllm_serve.py
class WeightSyncWorkerExtension (line 34) | class WeightSyncWorkerExtension:
method init_communicator (line 48) | def init_communicator(self, host: str, port: int, world_size: int, cli...
method update_named_param (line 115) | def update_named_param(self, name: str, dtype: str, shape: Sequence[in...
method close_communicator (line 149) | def close_communicator(self) -> None:
class ScriptArguments (line 163) | class ScriptArguments:
function llm_worker (line 310) | def llm_worker(
function chunk_list (line 364) | def chunk_list(lst: list, n: int) -> list[list]:
function main (line 384) | def main(script_args: ScriptArguments):
function make_parser (line 888) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr...
FILE: trl/skills/cli.py
function add_skills_subcommands (line 26) | def add_skills_subcommands(subparsers: argparse._SubParsersAction) -> None:
function cmd_install (line 90) | def cmd_install(args):
function cmd_uninstall (line 147) | def cmd_uninstall(args):
function cmd_list (line 162) | def cmd_list(args):
FILE: trl/skills/skills.py
function list_agent_names (line 50) | def list_agent_names() -> list[str]:
function _get_trl_skills_dir (line 60) | def _get_trl_skills_dir() -> Path:
function resolve_target_path (line 72) | def resolve_target_path(target: str | Path, scope: str = "project") -> P...
function _list_skills_in_dir (line 117) | def _list_skills_in_dir(skills_dir: Path) -> list[str]:
function list_skills (line 138) | def list_skills(target: str | Path | None = None, scope: str = "project"...
function _install_skill_to_dir (line 178) | def _install_skill_to_dir(
function install_skill (line 244) | def install_skill(
function _uninstall_skill_from_dir (line 294) | def _uninstall_skill_from_dir(skill_name: str, target_dir: Path) -> bool:
function uninstall_skill (line 326) | def uninstall_skill(skill_name: str, target: str | Path, scope: str = "p...
FILE: trl/trainer/base_config.py
class _BaseConfig (line 21) | class _BaseConfig(TrainingArguments):
method __post_init__ (line 104) | def __post_init__(self):
FILE: trl/trainer/base_trainer.py
class _BaseTrainer (line 26) | class _BaseTrainer(Trainer):
method create_model_card (line 32) | def create_model_card(
FILE: trl/trainer/callbacks.py
function _generate_completions (line 62) | def _generate_completions(
class SyncRefModelCallback (line 102) | class SyncRefModelCallback(TrainerCallback):
method __init__ (line 107) | def __init__(
method _sync_target_model (line 116) | def _sync_target_model(model, target_model, alpha):
method sync_target_model (line 121) | def sync_target_model(model, target_model, alpha):
method on_step_end (line 134) | def on_step_end(self, args, state, control, **kwargs):
class RichProgressCallback (line 143) | class RichProgressCallback(TrainerCallback):
method __init__ (line 148) | def __init__(self):
method on_train_begin (line 161) | def on_train_begin(self, args, state, control, **kwargs):
method on_step_end (line 174) | def on_step_end(self, args, state, control, **kwargs):
method on_prediction_step (line 181) | def on_prediction_step(self, args, state, control, eval_dataloader=Non...
method on_evaluate (line 190) | def on_evaluate(self, args, state, control, **kwargs):
method on_predict (line 198) | def on_predict(self, args, state, control, **kwargs):
method on_log (line 206) | def on_log(self, args, state, control, logs=None, **kwargs):
method on_train_end (line 239) | def on_train_end(self, args, state, control, **kwargs):
class LogCompletionsCallback (line 254) | class LogCompletionsCallback(TrainerCallback):
method __init__ (line 278) | def __init__(
method on_step_end (line 299) | def on_step_end(self, args, state, control, **kwargs):
class WeaveCallback (line 346) | class WeaveCallback(TrainerCallback):
method __init__ (line 410) | def __init__(
method _initialize_weave (line 438) | def _initialize_weave(self):
method is_evaluation_mode (line 473) | def is_evaluation_mode(self) -> bool:
method on_train_begin (line 477) | def on_train_begin(self, args, state, control, **kwargs):
method on_evaluate (line 481) | def on_evaluate(self, args, state, control, **kwargs):
class BEMACallback (line 575) | class BEMACallback(TrainerCallback):
method __init__ (line 637) | def __init__(
method _unwrap_model (line 666) | def _unwrap_model(model):
method on_train_begin (line 688) | def on_train_begin(
method _ema_beta (line 709) | def _ema_beta(self, step: int) -> float:
method _bema_alpha (line 714) | def _bema_alpha(self, step: int) -> float:
method _update_bema_weights (line 718) | def _update_bema_weights(self, step: int):
method on_step_end (line 731) | def on_step_end(
method on_train_end (line 754) | def on_train_end(self, args: TrainingArguments, state: TrainerState, c...
FILE: trl/trainer/dpo_config.py
class DPOConfig (line 22) | class DPOConfig(_BaseConfig):
method __post_init__ (line 310) | def __post_init__(self):
FILE: trl/trainer/dpo_trainer.py
function get_dataset_column_names (line 89) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list...
class DataCollatorForPreference (line 94) | class DataCollatorForPreference(DataCollatorMixin):
method torch_call (line 154) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
class DataCollatorForVisionPreference (line 215) | class DataCollatorForVisionPreference(DataCollatorMixin):
method torch_call (line 302) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
class DPOTrainer (line 406) | class DPOTrainer(_BaseTrainer):
method __init__ (line 501) | def __init__(
method _prepare_dataset (line 823) | def _prepare_dataset(
method _set_signature_columns_if_needed (line 951) | def _set_signature_columns_if_needed(self):
method _precompute_ref_logps (line 975) | def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_siz...
method _truncate_inputs (line 1017) | def _truncate_inputs(
method compute_ref_log_probs (line 1051) | def compute_ref_log_probs(self, inputs):
method _compute_loss_liger (line 1103) | def _compute_loss_liger(self, model, inputs, return_outputs):
method _compute_loss (line 1181) | def _compute_loss(self, model, inputs, return_outputs):
method compute_loss (line 1494) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method training_step (line 1501) | def training_step(self, *args, **kwargs):
method log (line 1505) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method prediction_step (line 1520) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_...
method _save_checkpoint (line 1532) | def _save_checkpoint(self, model, trial):
FILE: trl/trainer/grpo_config.py
class GRPOConfig (line 22) | class GRPOConfig(_BaseConfig):
method __post_init__ (line 869) | def __post_init__(self):
FILE: trl/trainer/grpo_trainer.py
class _SupportsReset (line 123) | class _SupportsReset(Protocol):
method reset (line 124) | def reset(self, **kwargs) -> str | None: ...
class GRPOTrainer (line 130) | class GRPOTrainer(_BaseTrainer):
method __init__ (line 268) | def __init__(
method _set_signature_columns_if_needed (line 852) | def _set_signature_columns_if_needed(self):
method get_train_dataloader (line 869) | def get_train_dataloader(self):
method _get_train_sampler (line 878) | def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
method _get_eval_sampler (line 914) | def _get_eval_sampler(self, eval_dataset) -> Sampler:
method _get_last_hidden_state (line 923) | def _get_last_hidden_state(
method get_high_entropy_mask (line 967) | def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.T...
method _get_per_token_logps_and_entropies (line 1006) | def _get_per_token_logps_and_entropies(
method training_step (line 1081) | def training_step(self, model, inputs, num_items_in_batch):
method _prepare_inputs (line 1093) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A...
method _log_completion_extra (line 1124) | def _log_completion_extra(self, column: str, values: list):
method _log_metric (line 1136) | def _log_metric(self, name: str, value: float):
method _calculate_rewards (line 1150) | def _calculate_rewards(self, inputs, prompts, completions, completion_...
method _tokenize_prompts (line 1241) | def _tokenize_prompts(self, prompts: list):
method _generate_single_turn (line 1284) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
method _get_tool_suffix_ids (line 1380) | def _get_tool_suffix_ids(self, tool_messages):
method _tool_call_loop (line 1401) | def _tool_call_loop(self, prompts, prompt_ids, completion_ids, complet...
method _generate (line 1577) | def _generate(self, prompts: list):
method _generate_and_score_completions (line 1692) | def _generate_and_score_completions(
method compute_liger_loss (line 2113) | def compute_liger_loss(self, unwrapped_model, inputs):
method compute_loss (line 2161) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method get_off_policy_mask (line 2172) | def get_off_policy_mask(
method get_gamma_weights (line 2195) | def get_gamma_weights(
method _compute_loss (line 2243) | def _compute_loss(self, model, inputs):
method prediction_step (line 2442) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_...
method log (line 2450) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 2519) | def _save_checkpoint(self, model, trial):
FILE: trl/trainer/kto_config.py
class KTOConfig (line 26) | class KTOConfig(_KTOConfig):
method __post_init__ (line 27) | def __post_init__(self):
FILE: trl/trainer/kto_trainer.py
class KTOTrainer (line 26) | class KTOTrainer(_KTOTrainer):
method __init__ (line 27) | def __init__(self, *args, **kwargs):
FILE: trl/trainer/model_config.py
class ModelConfig (line 19) | class ModelConfig:
method __post_init__ (line 183) | def __post_init__(self):
FILE: trl/trainer/reward_config.py
class RewardConfig (line 22) | class RewardConfig(_BaseConfig):
FILE: trl/trainer/reward_trainer.py
function _suppress_seqcls_cross_arch_keys (line 75) | def _suppress_seqcls_cross_arch_keys(logger: logging.Logger):
function _ignore_seqcls_cross_arch_keys (line 96) | def _ignore_seqcls_cross_arch_keys():
function suppress_seqcls_warning (line 121) | def suppress_seqcls_warning():
function get_dataset_column_names (line 134) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list...
class DataCollatorForPreference (line 139) | class DataCollatorForPreference(DataCollatorMixin):
method torch_call (line 200) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
class RewardTrainer (line 229) | class RewardTrainer(_BaseTrainer):
method __init__ (line 321) | def __init__(
method _prepare_dataset (line 537) | def _prepare_dataset(
method _set_signature_columns_if_needed (line 636) | def _set_signature_columns_if_needed(self):
method compute_loss (line 643) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method training_step (line 685) | def training_step(self, *args, **kwargs):
method log (line 689) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 703) | def _save_checkpoint(self, model, trial):
FILE: trl/trainer/rloo_config.py
class RLOOConfig (line 22) | class RLOOConfig(_BaseConfig):
method __post_init__ (line 547) | def __post_init__(self):
FILE: trl/trainer/rloo_trainer.py
class RLOOTrainer (line 104) | class RLOOTrainer(_BaseTrainer):
method __init__ (line 220) | def __init__(
method _set_signature_columns_if_needed (line 596) | def _set_signature_columns_if_needed(self):
method get_train_dataloader (line 613) | def get_train_dataloader(self):
method _get_train_sampler (line 622) | def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
method _get_eval_sampler (line 658) | def _get_eval_sampler(self, eval_dataset) -> Sampler:
method _get_per_token_logps_and_entropies (line 667) | def _get_per_token_logps_and_entropies(
method training_step (line 742) | def training_step(self, model, inputs, num_items_in_batch):
method _prepare_inputs (line 754) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A...
method _log_completion_extra (line 785) | def _log_completion_extra(self, column: str, values: list):
method _log_metric (line 797) | def _log_metric(self, name: str, value: float):
method _calculate_rewards (line 811) | def _calculate_rewards(self, inputs, prompts, completions, completion_...
method _tokenize_prompts (line 900) | def _tokenize_prompts(self, prompts: list):
method _generate_single_turn (line 941) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
method _generate (line 1031) | def _generate(self, prompts: list):
method _generate_and_score_completions (line 1080) | def _generate_and_score_completions(
method compute_loss (line 1370) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method _compute_loss (line 1375) | def _compute_loss(self, model, inputs):
method prediction_step (line 1435) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_...
method log (line 1443) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 1504) | def _save_checkpoint(self, model, trial):
FILE: trl/trainer/sft_config.py
class SFTConfig (line 23) | class SFTConfig(_BaseConfig):
method __post_init__ (line 269) | def __post_init__(self):
FILE: trl/trainer/sft_trainer.py
function get_dataset_column_names (line 86) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list...
class DataCollatorForLanguageModeling (line 91) | class DataCollatorForLanguageModeling(DataCollatorMixin):
method torch_call (line 164) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
method get_position_ids_from_packed_seq_lengths (line 231) | def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[l...
class DataCollatorForVisionLanguageModeling (line 259) | class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
method torch_call (line 344) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
method _collate_language_modeling (line 356) | def _collate_language_modeling(self, examples: list[dict[str, Any]]) -...
method _collate_prompt_completion (line 395) | def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -...
function dft_loss (line 494) | def dft_loss(outputs, labels, num_items_in_batch=None):
class SFTTrainer (line 511) | class SFTTrainer(_BaseTrainer):
method __init__ (line 609) | def __init__(
method _prepare_dataset (line 973) | def _prepare_dataset(
method _set_signature_columns_if_needed (line 1198) | def _set_signature_columns_if_needed(self):
method compute_loss (line 1209) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
method prediction_step (line 1339) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_...
method training_step (line 1345) | def training_step(self, *args, **kwargs):
method log (line 1349) | def log(self, logs: dict[str, float], start_time: float | None = None)...
method _save_checkpoint (line 1363) | def _save_checkpoint(self, model, trial):
FILE: trl/trainer/utils.py
function _is_port_free (line 73) | def _is_port_free(port: int, host: str = "127.0.0.1") -> bool:
function _find_free_port (line 83) | def _find_free_port() -> int:
function ensure_master_addr_port (line 93) | def ensure_master_addr_port(addr: str | None = None, port: int | None = ...
function pad (line 114) | def pad(
function disable_dropout_in_model (line 180) | def disable_dropout_in_model(model: torch.nn.Module) -> None:
function get_quantization_config (line 186) | def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConf...
function get_kbit_device_map (line 205) | def get_kbit_device_map() -> dict[str, int] | None:
function get_peft_config (line 212) | def get_peft_config(model_args: ModelConfig) -> "PeftConfig | None":
function prepare_deepspeed (line 238) | def prepare_deepspeed(
function generate_model_card (line 296) | def generate_model_card(
function get_comet_experiment_url (line 378) | def get_comet_experiment_url() -> str | None:
function get_trackio_space_url (line 391) | def get_trackio_space_url() -> str | None:
function log_table_to_comet_experiment (line 412) | def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
function flush_left (line 430) | def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tens...
function flush_right (line 495) | def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Ten...
function selective_log_softmax (line 525) | def selective_log_softmax(logits, index) -> torch.Tensor:
function entropy_from_logits (line 572) | def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> ...
function print_prompt_completions_sample (line 609) | def print_prompt_completions_sample(
class RepeatSampler (line 723) | class RepeatSampler(Sampler):
method __init__ (line 772) | def __init__(
method __iter__ (line 794) | def __iter__(self):
method __len__ (line 815) | def __len__(self) -> int:
function nanstd (line 820) | def nanstd(tensor: torch.Tensor, dim: int | tuple[int, ...] | None = Non...
function split_tensor_dict (line 856) | def split_tensor_dict(
function shuffle_sequence_dict (line 891) | def shuffle_sequence_dict(seq_dict: dict[str, Sequence | None]) -> dict[...
function nanmin (line 923) | def nanmin(tensor: torch.Tensor) -> torch.Tensor:
function nanmax (line 938) | def nanmax(tensor: torch.Tensor) -> torch.Tensor:
function identity (line 953) | def identity(x):
function split_pixel_values_by_grid (line 958) | def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[s...
function unsplit_pixel_values_by_grid (line 979) | def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[to...
function remove_none_values (line 1000) | def remove_none_values(example: TListOrMapping) -> TListOrMapping:
function create_model_from_path (line 1032) | def create_model_from_path(
function hash_module (line 1071) | def hash_module(module: torch.nn.Module) -> str:
function get_config_model_id (line 1082) | def get_config_model_id(config: PretrainedConfig) -> str:
class CausalLMOutputWithPastAndFlatLogits (line 1098) | class CausalLMOutputWithPastAndFlatLogits(CausalLMOutputWithPast):
function forward_masked_logits (line 1102) | def forward_masked_logits(
function use_adapter (line 1158) | def use_adapter(model: "PeftModel", adapter_name: str | None):
function start_event_loop_in_daemon (line 1197) | def start_event_loop_in_daemon(
function shutdown_event_loop_in_daemon (line 1228) | def shutdown_event_loop_in_daemon(
Condensed preview — 380 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (6,841K chars).
[
{
"path": ".github/ISSUE_TEMPLATE/bug-report.yml",
"chars": 2850,
"preview": "name: \"\\U0001F41B Bug Report\"\ndescription: Submit a bug report to help us improve TRL\nlabels: [ \"bug\" ]\nbody:\n - type: "
},
{
"path": ".github/ISSUE_TEMPLATE/feature-request.yml",
"chars": 1082,
"preview": "name: \"\\U0001F680 Feature request\"\ndescription: Submit a proposal/request for a new TRL feature\nlabels: [ \"Feature reque"
},
{
"path": ".github/ISSUE_TEMPLATE/new-trainer-addition.yml",
"chars": 1181,
"preview": "name: \"\\U0001F31F New trainer addition\"\ndescription: Submit a proposal/request to implement a new trainer for a post-tra"
},
{
"path": ".github/PULL_REQUEST_TEMPLATE.md",
"chars": 1349,
"preview": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, "
},
{
"path": ".github/codeql/custom-queries.qls",
"chars": 1013,
"preview": "import codeql\n\nfrom WorkflowString interpolation, Workflow workflow\nwhere \n interpolation.getStringValue().matches(\"${{"
},
{
"path": ".github/workflows/build_documentation.yml",
"chars": 401,
"preview": "name: Build documentation\n\non:\n push:\n branches:\n - main\n - doc-builder*\n - v*-release\n\nenv:\n TRL_EX"
},
{
"path": ".github/workflows/build_pr_documentation.yml",
"chars": 510,
"preview": "name: Build PR Documentation\n\non:\n pull_request:\n\nenv:\n TRL_EXPERIMENTAL_SILENCE: 1\n\nconcurrency:\n group: ${{ github."
},
{
"path": ".github/workflows/clear_cache.yml",
"chars": 829,
"preview": "name: \"Cleanup Cache\"\n\non:\n workflow_dispatch:\n schedule:\n - cron: \"0 0 * * *\"\n \njobs:\n cleanup:\n runs-on: u"
},
{
"path": ".github/workflows/codeQL.yml",
"chars": 598,
"preview": "name: \"CodeQL Analysis - Workflows\"\n\non:\n workflow_dispatch:\n\njobs:\n analyze:\n name: \"Analyze GitHub Workflows\"\n "
},
{
"path": ".github/workflows/docker-build.yml",
"chars": 2379,
"preview": "name: Build TRL Docker image\n\non:\n push:\n branches:\n - main\n workflow_dispatch:\n\nconcurrency:\n group: docker-"
},
{
"path": ".github/workflows/issue_auto_labeller.yml",
"chars": 307,
"preview": "name: \"Hugging Face Issue Labeler\"\non:\n issues:\n types: opened\n\njobs:\n triage:\n runs-on: ubuntu-latest\n permi"
},
{
"path": ".github/workflows/pr_style_bot.yml",
"chars": 4736,
"preview": "name: PR Style Bot\n\non:\n workflow_dispatch:\n\n\npermissions:\n contents: write\n pull-requests: write\n\njobs:\n run-style-"
},
{
"path": ".github/workflows/publish.yml",
"chars": 993,
"preview": "name: Publish to PyPI\n\non:\n push:\n branches:\n - main\n - v*-release\n paths:\n - \"VERSION\"\n\njobs:\n p"
},
{
"path": ".github/workflows/slow-tests.yml",
"chars": 2819,
"preview": "name: Slow tests (on push)\n\non:\n push:\n branches: [main]\n paths:\n # Run only when python files are modified\n"
},
{
"path": ".github/workflows/tests-experimental.yml",
"chars": 1708,
"preview": "name: Tests (experimental)\n\non:\n pull_request:\n paths:\n # Run only when relevant files are modified\n - \"tr"
},
{
"path": ".github/workflows/tests.yml",
"chars": 9027,
"preview": "name: Tests\n\non:\n push:\n branches:\n - main\n - ci-*\n pull_request:\n paths:\n # Run only when releva"
},
{
"path": ".github/workflows/tests_latest.yml",
"chars": 1891,
"preview": "name: Tests latest TRL release with dev dependencies\n\non:\n schedule:\n - cron: '0 0 * * *' # Runs daily at midnight "
},
{
"path": ".github/workflows/tests_transformers_branch.yml",
"chars": 3645,
"preview": "name: Tests against Transformers branch\n\non:\n workflow_dispatch:\n inputs:\n transformers_ref:\n descriptio"
},
{
"path": ".github/workflows/trufflehog.yml",
"chars": 454,
"preview": "on:\n push:\n\nname: Secret Leaks\n\njobs:\n trufflehog:\n runs-on: ubuntu-latest\n steps:\n - name: Checkout code\n "
},
{
"path": ".github/workflows/upload_pr_documentation.yml",
"chars": 376,
"preview": "name: Upload PR Documentation\n\non:\n workflow_run:\n workflows: [\"Build PR Documentation\"]\n types:\n - complete"
},
{
"path": ".gitignore",
"chars": 1612,
"preview": "*.bak\n.gitattributes\n.last_checked\n.gitconfig\n*.bak\n*.log\n*~\n~*\n_tmp*\ntmp*\ntags\n\n# Byte-compiled / optimized / DLL files"
},
{
"path": ".pre-commit-config.yaml",
"chars": 488,
"preview": "repos:\n - repo: https://github.com/astral-sh/ruff-pre-commit\n rev: v0.13.3\n hooks:\n - id: ruff-check\n "
},
{
"path": "AGENTS.md",
"chars": 5418,
"preview": "# AGENTS.md\n\n## Repository-specific guidance\n\n### Main code vs experimental code\n\nThe repository is separated into **mai"
},
{
"path": "CITATION.cff",
"chars": 1151,
"preview": "cff-version: 1.2.0\ntitle: 'TRL: Transformers Reinforcement Learning'\nmessage: >-\n If you use this software, please cite"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 5488,
"preview": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make particip"
},
{
"path": "CONTRIBUTING.md",
"chars": 21100,
"preview": "# How to contribute to TRL?\n\nEveryone is welcome to contribute, and we value everybody's contribution. Code contribution"
},
{
"path": "LICENSE",
"chars": 11355,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "MANIFEST.in",
"chars": 194,
"preview": "include LICENSE\ninclude CONTRIBUTING.md\ninclude README.md\ninclude trl/accelerate_configs/*.yaml\ninclude trl/templates/*."
},
{
"path": "MIGRATION.md",
"chars": 1299,
"preview": "# Migrating from TRL v0 to v1\n\nThis guide covers the breaking changes introduced in TRL v1 and how to update your code. "
},
{
"path": "Makefile",
"chars": 643,
"preview": ".PHONY: test precommit common_tests slow_tests tests_gpu test_experimental\n\ncheck_dirs := examples tests trl\n\nACCELERATE"
},
{
"path": "README.md",
"chars": 7852,
"preview": "# TRL - Transformers Reinforcement Learning\n\n<div style=\"text-align: center\">\n <picture>\n <source media=\"(pref"
},
{
"path": "RELEASE.md",
"chars": 4484,
"preview": "# Making a release\n\n> [!NOTE]\n> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We ne"
},
{
"path": "VERSION",
"chars": 10,
"preview": "1.0.0.dev0"
},
{
"path": "docker/trl/Dockerfile",
"chars": 221,
"preview": "FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel\nRUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lis"
},
{
"path": "docker/trl-dev/Dockerfile",
"chars": 327,
"preview": "FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel\nRUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lis"
},
{
"path": "docs/source/_toctree.yml",
"chars": 3222,
"preview": "- sections:\n - local: index\n title: TRL\n - local: installation\n title: Installation\n - local: quickstart\n ti"
},
{
"path": "docs/source/async_grpo_trainer.md",
"chars": 4526,
"preview": "# Asynchronous GRPO\n\n> [!IMPORTANT]\n> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed tr"
},
{
"path": "docs/source/bco_trainer.md",
"chars": 4201,
"preview": "# BCO Trainer\n\n[](https://huggingface.co/models?other=bc"
},
{
"path": "docs/source/bema_for_reference_model.md",
"chars": 749,
"preview": "# BEMA for Reference Model\n\nThis feature implements the BEMA algorithm to update the reference model during DPO training"
},
{
"path": "docs/source/callbacks.md",
"chars": 222,
"preview": "# Callbacks\n\n## RichProgressCallback\n\n[[autodoc]] RichProgressCallback\n\n## LogCompletionsCallback\n\n[[autodoc]] LogComple"
},
{
"path": "docs/source/chat_template_utils.md",
"chars": 281,
"preview": "# Chat template utilities\n\n## clone_chat_template\n\n[[autodoc]] clone_chat_template\n\n## is_chat_template_prefix_preservin"
},
{
"path": "docs/source/clis.md",
"chars": 13838,
"preview": "# Command Line Interfaces (CLIs)\n\nTRL provides a powerful command-line interface (CLI) to fine-tune large language model"
},
{
"path": "docs/source/community_tutorials.md",
"chars": 10139,
"preview": "# Community Tutorials\n\nCommunity tutorials are made by active members of the Hugging Face community who want to share th"
},
{
"path": "docs/source/cpo_trainer.md",
"chars": 12545,
"preview": "# CPO Trainer\n\n[](https://huggingface.co/models?other=cp"
},
{
"path": "docs/source/customization.md",
"chars": 3897,
"preview": "# Training customization\n\nTRL is designed with modularity in mind so that users are able to efficiently customize the tr"
},
{
"path": "docs/source/data_utils.md",
"chars": 251,
"preview": "# Data Utilities\n\n## is_conversational\n\n[[autodoc]] is_conversational\n\n## maybe_convert_to_chatml\n\n[[autodoc]] maybe_con"
},
{
"path": "docs/source/dataset_formats.md",
"chars": 42159,
"preview": "# Dataset formats and types\n\nThis guide provides an overview of the dataset formats and types supported by each trainer "
},
{
"path": "docs/source/deepspeed_integration.md",
"chars": 1516,
"preview": "# DeepSpeed Integration\n\n> [!WARNING]\n> Section under construction. Feel free to contribute!\n\nTRL supports training with"
},
{
"path": "docs/source/distributing_training.md",
"chars": 21329,
"preview": "# Distributing Training\n\n> [!WARNING]\n> Section under construction. Feel free to contribute!\n\n## Multi-GPU Training with"
},
{
"path": "docs/source/dpo_trainer.md",
"chars": 19942,
"preview": "# DPO Trainer\n\n[](https://huggingface.co/models?"
},
{
"path": "docs/source/example_overview.md",
"chars": 17074,
"preview": "# Examples\n\nThis directory contains a collection of examples that demonstrate how to use the TRL library for various app"
},
{
"path": "docs/source/experimental_overview.md",
"chars": 1613,
"preview": "# Experimental\n\nThis directory contains a minimal, clearly separated space for fast iteration on new ideas.\n\n> [!WARNING"
},
{
"path": "docs/source/gfpo.md",
"chars": 1653,
"preview": "# GFPO\n\nThis feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as pro"
},
{
"path": "docs/source/gkd_trainer.md",
"chars": 5826,
"preview": "# Generalized Knowledge Distillation Trainer\n\n[](https:/"
},
{
"path": "docs/source/gold_trainer.md",
"chars": 7600,
"preview": "# General Online Logit Distillation (GOLD) Trainer\n\n[](https://huggingface.co/models?other="
},
{
"path": "docs/source/grpo_with_replay_buffer.md",
"chars": 1677,
"preview": "# GRPO With Replay Buffer\n\nThis experimental trainer, trains a model with GRPO but replaces groups (and corresponding co"
},
{
"path": "docs/source/gspo_token.md",
"chars": 958,
"preview": "# GSPO-token\n\nIn the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors p"
},
{
"path": "docs/source/index.md",
"chars": 8888,
"preview": "<div style=\"text-align: center\">\n<picture>\n <source media=\"(prefers-color-scheme: light)\" srcset=\"https://huggingface"
},
{
"path": "docs/source/installation.md",
"chars": 824,
"preview": "# Installation\n\nYou can install TRL either from PyPI or from source:\n\n## PyPI\n\nInstall the library with pip or [uv](http"
},
{
"path": "docs/source/jobs_training.md",
"chars": 7133,
"preview": "# Training with Jobs\n\n[](https://huggingface.co/mode"
},
{
"path": "docs/source/judges.md",
"chars": 2745,
"preview": "# Judges\n\n> [!WARNING]\n> TRL Judges is an experimental API which is subject to change at any time. As of TRL v1.0, judge"
},
{
"path": "docs/source/kernels_hub.md",
"chars": 4682,
"preview": "# Kernels Hub Integration and Usage\n\n<img src=\"https://github.com/user-attachments/assets/4b5175f3-1d60-455b-8664-43b249"
},
{
"path": "docs/source/kto_trainer.md",
"chars": 10412,
"preview": "# KTO Trainer\n\n[](https://huggingface.co/models?other=kt"
},
{
"path": "docs/source/liger_kernel_integration.md",
"chars": 2405,
"preview": "# Liger Kernel Integration\n\n[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels d"
},
{
"path": "docs/source/lora_without_regret.md",
"chars": 14716,
"preview": "# LoRA Without Regret\n\nRecent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) "
},
{
"path": "docs/source/merge_model_callback.md",
"chars": 87,
"preview": "# MergeModelCallback\n\n[[autodoc]] experimental.merge_model_callback.MergeModelCallback\n"
},
{
"path": "docs/source/minillm_trainer.md",
"chars": 4040,
"preview": "# MiniLLM Trainer\n\n[](https://huggingfac"
},
{
"path": "docs/source/nash_md_trainer.md",
"chars": 9831,
"preview": "# Nash-MD Trainer\n\n[](https://huggingface.co/models"
},
{
"path": "docs/source/nemo_gym.md",
"chars": 10339,
"preview": "# NeMo Gym Integration\n\nNVIDIA NeMo Gym is a library for building RL environments for large language models. This integr"
},
{
"path": "docs/source/online_dpo_trainer.md",
"chars": 14725,
"preview": "# Online DPO Trainer\n\n[](https://huggingface.co/m"
},
{
"path": "docs/source/openenv.md",
"chars": 26203,
"preview": "# OpenEnv Integration for Training LLMs with Environments\n\n[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open"
},
{
"path": "docs/source/orpo_trainer.md",
"chars": 9480,
"preview": "# ORPO Trainer\n\n[](https://huggingface.co/models?other="
},
{
"path": "docs/source/paper_index.md",
"chars": 74999,
"preview": "# Paper Index\n\n<!-- Within sections, papers are sorted by publish dates -->\n\n## Group Relative Policy Optimization\n\nPape"
},
{
"path": "docs/source/papo_trainer.md",
"chars": 2427,
"preview": "# PAPO Trainer\n\n[](https://huggingface.co/models?other="
},
{
"path": "docs/source/peft_integration.md",
"chars": 23946,
"preview": "# PEFT Integration\n\nTRL supports [PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) methods f"
},
{
"path": "docs/source/ppo_trainer.md",
"chars": 16968,
"preview": "# PPO Trainer\n\n[](https://huggingface.co/models?other=pp"
},
{
"path": "docs/source/prm_trainer.md",
"chars": 6253,
"preview": "# PRM Trainer\n\n[](https://huggingface.co/models?other=pr"
},
{
"path": "docs/source/ptt_integration.md",
"chars": 4524,
"preview": "# Post-Training Toolkit Integration\n\n[Post-Training Toolkit](https://github.com/microsoft/post-training-toolkit) is a di"
},
{
"path": "docs/source/quickstart.md",
"chars": 3343,
"preview": "# Quickstart\n\nTRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-T"
},
{
"path": "docs/source/rapidfire_integration.md",
"chars": 12802,
"preview": "# RapidFire AI Integration\n\nRapidFire AI is an open-source experiment execution framework that enables concurrent traini"
},
{
"path": "docs/source/reducing_memory_usage.md",
"chars": 13225,
"preview": "# Reducing Memory Usage\n\nTraining workflows can often be optimized to **reduce memory consumption**, and TRL provides se"
},
{
"path": "docs/source/reward_trainer.md",
"chars": 10727,
"preview": "# Reward Modeling\n\n[](https://huggingface.co/"
},
{
"path": "docs/source/rewards.md",
"chars": 426,
"preview": "# Reward Functions\n\nThis module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer"
},
{
"path": "docs/source/rloo_trainer.md",
"chars": 33222,
"preview": "# RLOO Trainer\n\n[](https://huggingface.co/models?other="
},
{
"path": "docs/source/script_utils.md",
"chars": 358,
"preview": "# Scripts Utilities\n\n## ScriptArguments\n\n[[autodoc]] ScriptArguments\n\n## TrlParser\n\n[[autodoc]] TrlParser\n - parse_ar"
},
{
"path": "docs/source/sft_trainer.md",
"chars": 16657,
"preview": "# SFT Trainer\n\n[](https://huggingface.co/models?"
},
{
"path": "docs/source/speeding_up_training.md",
"chars": 5968,
"preview": "# Speeding Up Training\n\nThis guide covers various methods to accelerate training in TRL. Each technique includes minimal"
},
{
"path": "docs/source/trackio_integration.md",
"chars": 2113,
"preview": "# Trackio Integration\n\n[Trackio](https://huggingface.co/docs/trackio) is a lightweight, free experiment tracking library"
},
{
"path": "docs/source/unsloth_integration.md",
"chars": 5095,
"preview": "# Unsloth Integration\n\nUnsloth is an open‑source framework for fine‑tuning and reinforcement learning that trains LLMs ("
},
{
"path": "docs/source/use_model.md",
"chars": 2544,
"preview": "# Use model after training\n\nOnce you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you wi"
},
{
"path": "docs/source/vllm_integration.md",
"chars": 18472,
"preview": "# vLLM Integration\n\nThis document will guide you through the process of using vLLM with TRL for faster generation in onl"
},
{
"path": "docs/source/winrate_callback.md",
"chars": 77,
"preview": "# WinRateCallback\n\n[[autodoc]] experimental.winrate_callback.WinRateCallback\n"
},
{
"path": "docs/source/xpo_trainer.md",
"chars": 10176,
"preview": "# XPO Trainer\n\n[](https://huggingface.co/models?other=xp"
},
{
"path": "examples/README.md",
"chars": 113,
"preview": "# Examples\n\nPlease check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.\n"
},
{
"path": "examples/accelerate_configs/alst_ulysses_4gpu.yaml",
"chars": 1497,
"preview": "# ALST/Ulysses Sequence Parallelism with 2D Parallelism (DP + SP) for 4 GPUs\n#\n# This configuration enables 2D paralleli"
},
{
"path": "examples/accelerate_configs/context_parallel_2gpu.yaml",
"chars": 895,
"preview": "# Context Parallelism with FSDP for 2 GPUs\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndownca"
},
{
"path": "examples/accelerate_configs/deepspeed_zero1.yaml",
"chars": 441,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n deepspeed_multinode_launcher: standard\n gradient_ac"
},
{
"path": "examples/accelerate_configs/deepspeed_zero2.yaml",
"chars": 470,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n deepspeed_multinode_launcher: standard\n offload_opt"
},
{
"path": "examples/accelerate_configs/deepspeed_zero3.yaml",
"chars": 498,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n deepspeed_multinode_launcher: standard\n offload_opt"
},
{
"path": "examples/accelerate_configs/fsdp1.yaml",
"chars": 725,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nfs"
},
{
"path": "examples/accelerate_configs/fsdp2.yaml",
"chars": 627,
"preview": "# Requires accelerate 1.7.0 or higher\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf"
},
{
"path": "examples/accelerate_configs/multi_gpu.yaml",
"chars": 321,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\ngpu_ids: all\nmachine_ran"
},
{
"path": "examples/accelerate_configs/single_gpu.yaml",
"chars": 316,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: \"NO\"\ndowncast_bf16: 'no'\ngpu_ids: all\nmachine_rank: 0\n"
},
{
"path": "examples/cli_configs/example_config.yaml",
"chars": 405,
"preview": "# This is an example configuration file of TRL CLI, you can use it for \n# SFT like that: `trl sft --config config.yaml -"
},
{
"path": "examples/datasets/deepmath_103k.py",
"chars": 3256,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/hh-rlhf-helpful-base.py",
"chars": 5360,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/llava_instruct_mix.py",
"chars": 4382,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/lm-human-preferences-descriptiveness.py",
"chars": 4882,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/lm-human-preferences-sentiment.py",
"chars": 4560,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/math_shepherd.py",
"chars": 6495,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/prm800k.py",
"chars": 6016,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/rlaif-v.py",
"chars": 4581,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/tldr.py",
"chars": 4233,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/tldr_preference.py",
"chars": 4397,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/ultrafeedback-prompt.py",
"chars": 3470,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/datasets/ultrafeedback.py",
"chars": 5466,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/notebooks/README.md",
"chars": 4879,
"preview": "# Notebooks\n\nThis directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in di"
},
{
"path": "examples/notebooks/grpo_agent.ipynb",
"chars": 26400,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"63ceecbc-87ad-4ad3-a317-f49267ffc93b\",\n \"metadata\": {},\n \"so"
},
{
"path": "examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb",
"chars": 80697,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"lSR2nwdJg962\"\n },\n \"sou"
},
{
"path": "examples/notebooks/grpo_ministral3_vl.ipynb",
"chars": 25033,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"-J8iGzLf4rUJ\"\n },\n \"source\": [\n \"# GRPO"
},
{
"path": "examples/notebooks/grpo_qwen3_vl.ipynb",
"chars": 23081,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"-J8iGzLf4rUJ\"\n },\n \"source\": [\n \"# GRPO"
},
{
"path": "examples/notebooks/grpo_rnj_1_instruct.ipynb",
"chars": 20961,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"-J8iGzLf4rUJ\"\n },\n \"source\": [\n \"# GRPO"
},
{
"path": "examples/notebooks/grpo_trl_lora_qlora.ipynb",
"chars": 67623,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"27ozP4Uy-Cz2\"\n },\n \"sou"
},
{
"path": "examples/notebooks/openenv_sudoku_grpo.ipynb",
"chars": 97703,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"lSR2nwdJg962\"\n },\n \"source\": [\n \"# Open"
},
{
"path": "examples/notebooks/openenv_wordle_grpo.ipynb",
"chars": 105815,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"63ceecbc-87ad-4ad3-a317-f49267ffc93b\""
},
{
"path": "examples/notebooks/sft_ministral3_vl.ipynb",
"chars": 868460,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"UaDIwQOOjgAO\"\n },\n \"source\": [\n \"# Supe"
},
{
"path": "examples/notebooks/sft_nemotron_3.ipynb",
"chars": 33924,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"ovlmqboji0c\",\n \"metadata\": {\n \"id\": \"ovlmqboji0c\"\n },\n "
},
{
"path": "examples/notebooks/sft_qwen_vl.ipynb",
"chars": 860083,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"UaDIwQOOjgAO\"\n },\n \"source\": [\n \"# Supe"
},
{
"path": "examples/notebooks/sft_tool_calling.ipynb",
"chars": 39224,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"ii5Zkit6eSqU\"\n },\n \"source\": [\n \"# Teac"
},
{
"path": "examples/notebooks/sft_trl_lora_qlora.ipynb",
"chars": 42378,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"5oqSnSaqLWAL\"\n },\n \"source\": [\n \"# Supe"
},
{
"path": "examples/scripts/async_grpo.py",
"chars": 2079,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/bco.py",
"chars": 5799,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/cpo.py",
"chars": 3500,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/dpo.py",
"chars": 900,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/dpo_vlm.py",
"chars": 4263,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/evals/judge_tldr.py",
"chars": 4119,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/gkd.py",
"chars": 5023,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/grpo_2048.py",
"chars": 5041,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/grpo_agent.py",
"chars": 10114,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/grpo_vlm.py",
"chars": 5189,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/gspo.py",
"chars": 4405,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/gspo_vlm.py",
"chars": 4825,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/kto.py",
"chars": 3605,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/mpo_vlm.py",
"chars": 4248,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/nash_md.py",
"chars": 5413,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/nemo_gym/README.md",
"chars": 479,
"preview": "# Post-training with NeMo Gym and TRL\n\nThis integration supports training language models in NeMo-Gym environments using"
},
{
"path": "examples/scripts/nemo_gym/config.yaml",
"chars": 898,
"preview": "# Model\nmodel_name: \"Qwen/Qwen2.5-1.5B-Instruct\"\n\n# Data\ndataset_path: \"/home/ubuntu/Gym/resources_servers/workplace_ass"
},
{
"path": "examples/scripts/nemo_gym/deepspeed_zero3.yaml",
"chars": 499,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n deepspeed_multinode_launcher: standard\n offload_opt"
},
{
"path": "examples/scripts/nemo_gym/submit.sh",
"chars": 3517,
"preview": "#!/bin/bash\n#SBATCH -A account\n#SBATCH -p partition\n#SBATCH -N 5\n#SBATCH --gres gpu:8\n#SBATCH --ntasks-per-node=1\n#SBATC"
},
{
"path": "examples/scripts/nemo_gym/train_multi_environment.py",
"chars": 14681,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/online_dpo.py",
"chars": 5688,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/online_dpo_vlm.py",
"chars": 7770,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/browsergym.py",
"chars": 20278,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/browsergym_llm.py",
"chars": 16225,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/carla.py",
"chars": 6922,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/catch.py",
"chars": 10730,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/echo.py",
"chars": 2713,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/sudoku.py",
"chars": 29341,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/openenv/sudoku_prompt.txt",
"chars": 4072,
"preview": "You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.\n\n##"
},
{
"path": "examples/scripts/openenv/wordle.py",
"chars": 4152,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/orpo.py",
"chars": 3584,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/ppo/ppo.py",
"chars": 6288,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/ppo/ppo_tldr.py",
"chars": 6801,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/prm.py",
"chars": 4529,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/reward_modeling.py",
"chars": 4398,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/rloo.py",
"chars": 3461,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/rloo_vlm.py",
"chars": 5189,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft.py",
"chars": 900,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_gemma3.py",
"chars": 2019,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_gpt_oss.py",
"chars": 3211,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_nemotron_3.py",
"chars": 4086,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_tiny_aya_tool_calling.py",
"chars": 5141,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_video_llm.py",
"chars": 8058,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_vlm.py",
"chars": 3921,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/sft_vlm_gemma3.py",
"chars": 6609,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "examples/scripts/tiny_aya_chat_template.jinja",
"chars": 5590,
"preview": "{{ bos_token }}{% set ns = namespace(system_prompt=false, expect_user=true) %}{% for message in messages %}{% if message"
},
{
"path": "examples/scripts/xpo.py",
"chars": 4838,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "pyproject.toml",
"chars": 5147,
"preview": "[build-system]\nrequires = [\"setuptools >= 77.0.3\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"trl\"\ndescr"
},
{
"path": "requirements.txt",
"chars": 54,
"preview": "accelerate>=1.4.0\ndatasets>=3.0.0\ntransformers>=4.56.2"
},
{
"path": "scripts/add_copyrights.py",
"chars": 3356,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_harmony_dataset.py",
"chars": 22888,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_tiny_models.py",
"chars": 18193,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_toolcall_dataset.py",
"chars": 16851,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_zen_dataset.py",
"chars": 37894,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_zen_image_dataset.py",
"chars": 25217,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/generate_zen_multi_image_dataset.py",
"chars": 22287,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "scripts/log_reports.py",
"chars": 5723,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/__init__.py",
"chars": 611,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/conftest.py",
"chars": 3202,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/data/template.jinja",
"chars": 4168,
"preview": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].con"
},
{
"path": "tests/distributed/__init__.py",
"chars": 612,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/distributed/data/accelerate_configs/ddp.yaml",
"chars": 44,
"preview": "distributed_type: MULTI_GPU\nnum_processes: 2"
},
{
"path": "tests/distributed/data/accelerate_configs/fsdp2.yaml",
"chars": 70,
"preview": "distributed_type: FSDP\nfsdp_config:\n fsdp_version: 2\nnum_processes: 2"
},
{
"path": "tests/distributed/data/accelerate_configs/zero2.yaml",
"chars": 78,
"preview": "distributed_type: DEEPSPEED\ndeepspeed_config:\n zero_stage: 2\nnum_processes: 2"
},
{
"path": "tests/distributed/data/accelerate_configs/zero3.yaml",
"chars": 78,
"preview": "distributed_type: DEEPSPEED\ndeepspeed_config:\n zero_stage: 3\nnum_processes: 2"
},
{
"path": "tests/distributed/test_distributed.py",
"chars": 10352,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/experimental/__init__.py",
"chars": 611,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/experimental/test_async_grpo_trainer.py",
"chars": 5793,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "tests/experimental/test_bco_trainer.py",
"chars": 17246,
"preview": "# Copyright 2020-2026 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
}
]
// ... and 180 more files (download for full content)
About this extraction
This page contains the full source code of the huggingface/trl GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 380 files (6.3 MB), approximately 1.7M tokens, and a symbol index with 1986 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.