Repository: sii-research/siiRL
Branch: main
Commit: 89d8764b6133
Files: 391
Total size: 3.8 MB
Directory structure:
gitextract_ufof6x83/
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README-zh.md
├── README.md
├── docker/
│ ├── Dockerfile.cu124
│ └── Dockerfile.cu126
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── examples/
│ │ ├── config.rst
│ │ ├── cpgd_example.rst
│ │ ├── deepscaler_example.rst
│ │ ├── embodied_srpo_example.rst
│ │ ├── megatron_backend_example.rst
│ │ └── mm_eureka_example.rst
│ ├── hardware_tutorial/
│ │ ├── ascend_profiling_en.rst
│ │ ├── ascend_quickstart.rst
│ │ └── metax_quickstart.rst
│ ├── index.rst
│ ├── preparation/
│ │ ├── prepare_data.rst
│ │ └── reward_function.rst
│ ├── programming_guide/
│ │ ├── code_structure.rst
│ │ ├── siiRL_code_explained.rst
│ │ ├── siirl_architecture_guide.rst
│ │ └── srpo_code_explained.rst
│ ├── requirements-docs.txt
│ ├── start/
│ │ ├── install.rst
│ │ └── quickstart.rst
│ └── user_interface/
│ ├── filter_interface.rst
│ ├── metrics_interface.rst
│ ├── pipeline_interface.rst
│ └── reward_interface.rst
├── examples/
│ ├── cpgd_trainer/
│ │ ├── run_qwen2_5-7b.sh
│ │ ├── run_qwen2_5_vl-72b.sh
│ │ ├── run_qwen2_5_vl-7b.sh
│ │ ├── run_qwen3-1.7b.sh
│ │ └── run_qwen3-8b.sh
│ ├── custom_pipeline_example/
│ │ └── custom_grpo.py
│ ├── custom_reward/
│ │ ├── rewardfunc_gsm8k.py
│ │ └── run_qwen2_5-7b-custom_reward.sh
│ ├── dapo_trainer/
│ │ ├── run_qwen2_5-7b.sh
│ │ ├── run_qwen3-235b-megatron-gspo.sh
│ │ └── run_qwen3-8b.sh
│ ├── data_preprocess/
│ │ ├── deepscaler.py
│ │ ├── geo3k.py
│ │ ├── gsm8k.py
│ │ ├── math_dataset.py
│ │ └── mm_eureka.py
│ ├── embodied_srpo_trainer/
│ │ ├── run_openvla_oft_libero_goal.sh
│ │ ├── run_openvla_oft_libero_long.sh
│ │ ├── run_openvla_oft_libero_object.sh
│ │ └── run_openvla_oft_libero_spatial.sh
│ ├── experimental/
│ │ ├── marft/
│ │ │ ├── config/
│ │ │ │ ├── code_env.py
│ │ │ │ ├── math_env.py
│ │ │ │ ├── process.py
│ │ │ │ ├── workflow_marft.yaml
│ │ │ │ └── workflow_marft_code.yaml
│ │ │ └── run_qwen2_5-3b_marft.sh
│ │ └── multiturn_server/
│ │ └── run_qwen2_5-3b_grpo_multiturn_vllm.sh
│ ├── grpo_trainer/
│ │ ├── run_qwen2_5-32b-metax.sh
│ │ ├── run_qwen2_5-32b-npu.sh
│ │ ├── run_qwen2_5-72b-npu.sh
│ │ ├── run_qwen2_5-7b-npu-e2e_prof.sh
│ │ ├── run_qwen2_5-7b-npu-mindspeed.sh
│ │ ├── run_qwen2_5-7b-npu.sh
│ │ ├── run_qwen2_5-7b.sh
│ │ ├── run_qwen2_5_vl-72b.sh
│ │ ├── run_qwen2_5_vl-7b-npu.sh
│ │ ├── run_qwen2_5_vl-7b.sh
│ │ ├── run_qwen3-235b-megatron.sh
│ │ ├── run_qwen3-235b-npu-mindspeed.sh
│ │ ├── run_qwen3-30b-npu-mindspeed.sh
│ │ ├── run_qwen3-8b-megatron.sh
│ │ └── run_qwen3-8b.sh
│ ├── gspo_trainer/
│ │ ├── run_qwen3-1.7b.sh
│ │ ├── run_qwen3-235b-megatron.sh
│ │ └── run_qwen3-30b-gspo-megatron.sh
│ ├── multi_turn/
│ │ ├── config/
│ │ │ ├── interaction_config/
│ │ │ │ └── gsm8k_interaction_config.yaml
│ │ │ └── tool_config/
│ │ │ └── gsm8k_tool_config.yaml
│ │ └── gsm8k/
│ │ └── run_qwen2_5-3b_grpo_multiturn_sglang.sh
│ └── ppo_trainer/
│ ├── run_qwen2_5-72b.sh
│ ├── run_qwen3-8b-megatron.sh
│ └── run_qwen3-8b.sh
├── pyproject.toml
├── requirements-npu.txt
├── requirements.txt
├── setup.py
├── siirl/
│ ├── __init__.py
│ ├── dag_worker/
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ ├── constants.py
│ │ ├── core_algos.py
│ │ ├── dag_utils.py
│ │ ├── dagworker.py
│ │ ├── data_structures.py
│ │ ├── metric_aggregator.py
│ │ ├── metrics_collector.py
│ │ └── validator.py
│ ├── data_coordinator/
│ │ ├── __init__.py
│ │ ├── data_buffer.py
│ │ ├── dataloader/
│ │ │ ├── __init__.py
│ │ │ ├── data_loader_node.py
│ │ │ ├── embodied_preprocess.py
│ │ │ ├── partitioned_dataset.py
│ │ │ └── vision_utils.py
│ │ ├── protocol.py
│ │ └── sample.py
│ ├── engine/
│ │ ├── __init__.py
│ │ ├── actor/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── dp_actor.py
│ │ │ ├── embodied_actor.py
│ │ │ └── megatron_actor.py
│ │ ├── base_worker/
│ │ │ ├── __init__.py
│ │ │ ├── base/
│ │ │ │ ├── __init__.py
│ │ │ │ └── worker.py
│ │ │ ├── megatron/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── npu_mbridge_patch.py
│ │ │ │ └── worker.py
│ │ │ ├── register_center/
│ │ │ │ ├── __init__.py
│ │ │ │ └── register_center.py
│ │ │ └── resouce_pool.py
│ │ ├── critic/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── dp_critic.py
│ │ │ └── megatron_critic.py
│ │ ├── fsdp_workers.py
│ │ ├── megatron_workers.py
│ │ ├── reward_manager/
│ │ │ ├── __init__.py
│ │ │ ├── dapo.py
│ │ │ ├── embodied.py
│ │ │ ├── naive.py
│ │ │ └── parallel.py
│ │ ├── reward_model/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ └── reward_model.py
│ │ ├── rollout/
│ │ │ ├── __init__.py
│ │ │ ├── async_server.py
│ │ │ ├── base.py
│ │ │ ├── embodied_rollout.py
│ │ │ ├── hf_rollout.py
│ │ │ ├── schemas.py
│ │ │ ├── sglang_rollout/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── async_sglang_server.py
│ │ │ │ ├── sglang_rollout.py
│ │ │ │ └── utils.py
│ │ │ └── vllm_rollout/
│ │ │ ├── __init__.py
│ │ │ ├── vllm_async_server.py
│ │ │ └── vllm_rollout_spmd.py
│ │ └── sharding_manager/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── fsdp_hf.py
│ │ ├── fsdp_sglang.py
│ │ ├── fsdp_ulysses.py
│ │ ├── fsdp_vllm.py
│ │ ├── megatron_sglang.py
│ │ └── megatron_vllm.py
│ ├── environment/
│ │ └── embodied/
│ │ ├── __init__.py
│ │ ├── adapters/
│ │ │ ├── __init__.py
│ │ │ └── libero.py
│ │ ├── base.py
│ │ └── venv.py
│ ├── execution/
│ │ ├── dag/
│ │ │ ├── __init__.py
│ │ │ ├── builtin_pipelines.py
│ │ │ ├── config_loader.py
│ │ │ ├── node.py
│ │ │ ├── pipeline.py
│ │ │ ├── task_graph.py
│ │ │ └── task_loader.py
│ │ ├── metric_worker/
│ │ │ ├── metric_worker.py
│ │ │ └── utils.py
│ │ ├── rollout_flow/
│ │ │ ├── multi_agent/
│ │ │ │ ├── multiagent_generate.py
│ │ │ │ └── utils.py
│ │ │ └── multiturn/
│ │ │ ├── __init__.py
│ │ │ ├── agent_loop/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── agent_loop.py
│ │ │ │ ├── single_turn_agent_loop.py
│ │ │ │ └── tool_agent_loop.py
│ │ │ ├── interactions/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── gsm8k_interaction.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ └── interaction_registry.py
│ │ │ └── tools/
│ │ │ ├── __init__.py
│ │ │ ├── base_tool.py
│ │ │ ├── geo3k_tool.py
│ │ │ ├── gsm8k_tool.py
│ │ │ ├── mcp_base_tool.py
│ │ │ ├── mcp_search_tool.py
│ │ │ ├── sandbox_fusion_tools.py
│ │ │ ├── schemas.py
│ │ │ ├── search_tool.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── mcp_clients/
│ │ │ │ ├── McpClientManager.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── utils.py
│ │ │ ├── search_r1_like_utils.py
│ │ │ └── tool_registry.py
│ │ └── scheduler/
│ │ ├── __init__.py
│ │ ├── enums.py
│ │ ├── graph_updater.py
│ │ ├── launch.py
│ │ ├── process_group_manager.py
│ │ ├── ray_actor_manager.py
│ │ ├── resource_manager.py
│ │ ├── reward.py
│ │ └── task_scheduler.py
│ ├── main_dag.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embodied/
│ │ │ ├── openvla/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration_prismatic.py
│ │ │ │ ├── modeling_prismatic.py
│ │ │ │ └── processing_prismatic.py
│ │ │ └── openvla_oft/
│ │ │ ├── __init__.py
│ │ │ ├── configuration_prismatic.py
│ │ │ ├── constants.py
│ │ │ ├── modeling_prismatic.py
│ │ │ ├── processing_prismatic.py
│ │ │ └── train_utils.py
│ │ ├── llama/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── llama_loader.py
│ │ │ │ ├── llama_loader_depracated.py
│ │ │ │ └── llama_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_llama_megatron.py
│ │ ├── loader.py
│ │ ├── mcore/
│ │ │ ├── __init__.py
│ │ │ ├── config_converter.py
│ │ │ ├── loader.py
│ │ │ ├── mbridge.py
│ │ │ ├── model_forward.py
│ │ │ ├── model_forward_fused.py
│ │ │ ├── model_initializer.py
│ │ │ ├── patch_v012.py
│ │ │ ├── registry.py
│ │ │ ├── saver.py
│ │ │ ├── util.py
│ │ │ └── weight_converter.py
│ │ ├── model_utils/
│ │ │ ├── __init__.py
│ │ │ └── visual.py
│ │ ├── patcher.py
│ │ ├── qwen2/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── qwen2_loader.py
│ │ │ │ ├── qwen2_loader_depracated.py
│ │ │ │ └── qwen2_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_qwen2_megatron.py
│ │ ├── registry.py
│ │ ├── transformers/
│ │ │ ├── __init__.py
│ │ │ ├── internvl.py
│ │ │ ├── internvl_chat/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration_intern_vit.py
│ │ │ │ ├── configuration_internlm2.py
│ │ │ │ ├── configuration_internvl_chat.py
│ │ │ │ ├── modeling_intern_vit.py
│ │ │ │ ├── modeling_internlm2.py
│ │ │ │ ├── modeling_internvl_chat.py
│ │ │ │ ├── tokenization_internlm2.py
│ │ │ │ └── tokenization_internlm2_fast.py
│ │ │ ├── kimi_vl.py
│ │ │ ├── llama.py
│ │ │ ├── monkey_patch.py
│ │ │ ├── npu_patch.py
│ │ │ ├── qwen2.py
│ │ │ ├── qwen2_5_vl.py
│ │ │ ├── qwen2_vl.py
│ │ │ └── transformers_compat.py
│ │ └── weight_loader_registry.py
│ ├── params/
│ │ ├── __init__.py
│ │ ├── dag_args.py
│ │ ├── data_args.py
│ │ ├── display_dict.py
│ │ ├── embodied_args.py
│ │ ├── model_args.py
│ │ ├── parser.py
│ │ ├── profiler_args.py
│ │ └── training_args.py
│ ├── third_party/
│ │ ├── __init__.py
│ │ └── sglang/
│ │ ├── __init__.py
│ │ └── parallel_state.py
│ ├── user_interface/
│ │ ├── filter_interface/
│ │ │ ├── __init__.py
│ │ │ ├── dapo.py
│ │ │ └── embodied.py
│ │ └── rewards_interface/
│ │ └── custom_gsm8k_reward.py
│ └── utils/
│ ├── __init__.py
│ ├── checkpoint/
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ ├── fsdp_checkpoint_manager.py
│ │ └── megatron_checkpoint_manager.py
│ ├── debug/
│ │ ├── __init__.py
│ │ ├── mstx_profile.py
│ │ ├── performance.py
│ │ └── profile.py
│ ├── embodied/
│ │ ├── __init__.py
│ │ ├── libero_utils.py
│ │ ├── openvla_utils.py
│ │ └── video_emb.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ └── torch_functional.py
│ ├── extras/
│ │ ├── __init__.py
│ │ ├── device.py
│ │ ├── fs.py
│ │ ├── hdfs_io.py
│ │ ├── import_utils.py
│ │ ├── misc.py
│ │ ├── net_utils.py
│ │ ├── packages.py
│ │ ├── patch.py
│ │ ├── py_functional.py
│ │ └── ray_utils.py
│ ├── import_string.py
│ ├── kernel/
│ │ ├── __init__.py
│ │ ├── kernels.py
│ │ └── linear_cross_entropy.py
│ ├── logger/
│ │ ├── __init__.py
│ │ ├── aggregate_logger.py
│ │ ├── logging_utils.py
│ │ └── tracking.py
│ ├── megatron/
│ │ ├── __init__.py
│ │ ├── dist_checkpointing.py
│ │ ├── megatron_utils.py
│ │ ├── memory.py
│ │ ├── memory_buffer.py
│ │ ├── optimizer.py
│ │ ├── pipeline_parallel.py
│ │ ├── sequence_parallel.py
│ │ └── tensor_parallel.py
│ ├── memory_utils.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ └── metric_utils.py
│ ├── model_utils/
│ │ ├── __init__.py
│ │ ├── activation_offload.py
│ │ ├── attention_utils.py
│ │ ├── flops_counter.py
│ │ ├── fsdp_utils.py
│ │ ├── model.py
│ │ ├── npu_utils.py
│ │ ├── seqlen_balancing.py
│ │ ├── tensordict_utils.py
│ │ ├── torch_dtypes.py
│ │ ├── torch_functional.py
│ │ ├── ulysses.py
│ │ └── vllm_utils.py
│ └── reward_score/
│ ├── __init__.py
│ ├── embodied.py
│ ├── geo3k.py
│ ├── gsm8k.py
│ ├── math.py
│ ├── math_batch.py
│ ├── math_dapo.py
│ ├── math_verify.py
│ ├── mm_eureka.py
│ ├── prime_code/
│ │ ├── __init__.py
│ │ ├── testing_util.py
│ │ └── utils.py
│ ├── prime_math/
│ │ ├── __init__.py
│ │ ├── grader.py
│ │ └── math_normalize.py
│ ├── sandbox_fusion/
│ │ ├── __init__.py
│ │ └── utils.py
│ └── search_r1_like_qa_em.py
└── tests/
├── __init__.py
├── dag/
│ ├── test_config_loader.py
│ ├── test_node.py
│ ├── test_task_graph.py
│ └── test_task_loader.py
├── dag_worker/
│ ├── test_dag_worker.py
│ ├── test_dapo_merge.py
│ └── test_dapo_pipeline.py
├── data_buffer/
│ ├── detailed_put_performance_test.py
│ ├── performance_test_data_buffer.py
│ └── test_data_buffer.py
└── scheduler/
├── test_process_group_manager.py
└── test_task_scheduler.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
**/*.pt
**/checkpoints
**/wget-log
**/_build/
**/*.ckpt
**/outputs
**/*.tar.gz
**/playground
**/wandb
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
dataset/*
tensorflow/my_graph/*
.idea/
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
tmp/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
.venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# vscode
.vscode
# Mac
.DS_Store
# vim
*.swp
# ckpt
*.lock
# data
*.parquet
# local logs
logs
log
outputs
.history
*tensorboard
tensorboard/
# version file
siirl/_version.py
================================================
FILE: .pre-commit-config.yaml
================================================
# Default list of files to exclude from checks.
# Add any other paths that should be ignored by all hooks.
exclude: |
(?x)^(
docs/.*|
build/.*
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: [--maxkb=500]
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.6
hooks:
- id: ruff
args: ["--fix", "--show-fixes", "--output-format=full"]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.0
hooks:
- id: codespell
args:
- --skip="*.json,*.txt"
- --ignore-words-list=nd,repostory
================================================
FILE: .readthedocs.yaml
================================================
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the OS, Python version, and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.11"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/conf.py
# Optionally, but recommended,
# declare the Python requirements required to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements-docs.txt
- method: pip
path: .
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to siiRL
Thank you for considering contributing to siiRL!
We welcome contributions in various forms, including but not limited to:
- Reporting a bug
- Submitting a fix
- Discussing the current state of the code
- Proposing new features
- Becoming a maintainer
- Reviewing pull requests
- Adding or improving documentation
- ...
## Getting Started
To get started, please fork the repository and work from the latest `main` branch.
### Reporting Bugs
If you find a bug, please open an issue on our GitHub repository. When creating a bug report, please include as many details as possible and fill out the required template; detailed information helps us resolve issues faster.
### Suggesting Enhancements
If you have an idea for a new feature or an enhancement to an existing one, please open an issue on our GitHub repository. This allows for a discussion with the community and the project maintainers.
### Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. Issue that pull request!
## Styleguides
### Git Commit Messages
- Use the present tense ("Add feature" not "Added feature").
- Use the imperative mood ("Move A to..." not "Moves A to...").
- Limit the first line to 72 characters or less.
- Reference issues and pull requests liberally after the first line.
## Any questions?
Don't hesitate to contact us if you have any questions. You can reach out to us by opening an issue on GitHub.
We are excited to see your contributions!
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README-zh.md
================================================
# siiRL: Shanghai Innovation Institute RL Framework for Advanced LLMs and Multi-Agent Systems
| 📄 Paper | 📚 Documentation | Feishu Group | WeChat Group | English |
**siiRL** is a novel, **fully distributed reinforcement learning (RL) framework** developed by researchers at the **Shanghai Innovation Institute**. It is designed to break through the scalability bottlenecks of large language model (LLM) post-training and to support future multi-agent research.
By removing the centralized dataflow controller found in mainstream frameworks, siiRL achieves **near-linear scalability** and **significant throughput gains**, while its modular DAG design provides **great flexibility**, opening up new possibilities for RL-based LLM development.
---
## 🚀 Highlights
+ **Near-Linear Scalability**: The multi-controller paradigm distributes control logic and data management across all worker nodes, eliminating centralized bottlenecks and enabling near-linear scaling to thousands of GPUs.
+ **SOTA Throughput**: The fully distributed dataflow architecture minimizes communication and I/O overhead, achieving state-of-the-art throughput in data-intensive scenarios.
+ **Flexible DAG-Defined Pipelines**: Decouple your algorithmic logic from the physical hardware. With siiRL, complex RL workflows are defined as a simple directed acyclic graph (DAG), enabling rapid, low-cost, code-free experimentation.
+ **Cross-Hardware Compatibility**: siiRL now officially supports Huawei Ascend NPUs, providing a high-performance alternative for training and inference across hardware platforms.
+ **Proven Performance & Stability**: Extensively benchmarked on models from 7B to 72B, siiRL delivers excellent performance across a wide range of tasks. Its advantages are especially evident in data-intensive workloads such as long-context and multi-modal training.
---
## 📰 News
* **[2025/11]**: siiRL now supports Vision-Language-Action (VLA) model training based on the [SRPO (Self-Referential Policy Optimization for Vision-Language-Action Models)](https://arxiv.org/pdf/2511.15605) algorithm, enabling embodied RL training for robotics tasks. See the [documentation](/docs/examples/embodied_srpo_example.rst) for usage details.
* **[2025/09]**: siiRL now integrates the Megatron training backend and supports MoE model training. Performance has been validated on Qwen3-MoE models (30B, 235B).
* **[2025/09]**: Through collaboration with major vendors including Huawei Ascend, MetaX, and Alibaba Cloud, siiRL now scales stably from 32 up to 1024 devices on their GPU clusters, with over 90% linear scaling efficiency.
* **[2025/09]**: siiRL supports multi-turn interactions between multiple agents and the environment.
* **[2025/07]**: siiRL adds [MARFT](https://arxiv.org/pdf/2504.16129) support for LaMAS, enabling RL fine-tuning of multi-LLM agents via Flex-POMDP.
* **[2025/07]**: siiRL now supports [CPGD](https://arxiv.org/pdf/2505.12504v1), an algorithm that improves RL training stability and performance by regularizing large policy updates.
* **[2025/07]**: We are excited to release siiRL to the open-source community! Check out our [paper](https://arxiv.org/abs/2507.13833) for a deep dive into its architecture and evaluation.
---
## 💡 Architecture Overview
siiRL is a fully distributed RL framework designed for large-scale clusters. It adopts a multi-controller paradigm that dispatches all computation and dataflow evenly across every GPU. siiRL consists of three main components: the DAG Planner, the DAG Workers, and the Data Coordinator.
Figure 1. Overview of the siiRL architecture.
siiRL is a **fully distributed, multi-controller architecture**.
Key components include:
* **DAG Planner**: Translates the user-defined DAG into a serialized pipeline executable by each DAG Worker.
* **DAG Workers**: The core execution units; each DAG Worker is bound to a single GPU and runs its assigned tasks independently.
* **Data Coordinator**: A set of distributed components (the `Distributed Dataloader` and `Distributed Databuffer`) that manage the entire data lifecycle, from initial loading to intermediate data redistribution, without a central coordinator.
### Typical Supported Models & Algorithms
| Models | Algorithms |
| --- | --- |
| **Qwen2.5 series**: Qwen2.5-7B, Qwen2.5-72B, Qwen2.5-VL-7B, Qwen2.5-VL-72B | Reinforcement learning |
| **Qwen3 series**: Qwen3-1.7B, Qwen3-30B, Qwen3-235B-A22B (MoE) | |
| **VLA models** | |
## 🧪 Experiments
We comprehensively evaluated siiRL's performance and scalability and compared it against verl, a leading RL framework. The experiments show that siiRL delivers outstanding performance across all metrics.
### End-to-End Throughput
Under the standard PPO and GRPO algorithms, siiRL's throughput surpasses the baseline across the board. In particular, under the more data-intensive GRPO algorithm, siiRL's fully distributed architecture effectively resolves the data bottleneck, achieving up to a **2.62x** performance improvement.
Figure 2: End-to-end performance comparison under PPO
Figure 3: End-to-end performance comparison under GRPO
### Large-Scale Scalability
siiRL demonstrates near-linear scalability, scaling smoothly to 1024 GPUs. In contrast, the baseline framework fails under identical conditions with OOM (out-of-memory) errors caused by its single-point data bottleneck. At the largest batch size the baseline can support, siiRL's performance advantage reaches up to **7x**.
Figure 4: siiRL scalability test
Figure 5: Performance comparison at the baseline's maximum load
### Long-Context Performance
For long-context tasks, data transfer overhead becomes the dominant bottleneck. siiRL's distributed dataflow design makes its performance advantage grow as context length increases, achieving up to a **2.03x** throughput improvement and successfully running 72B-model long-context tasks that the baseline could not handle.
Figure 6: Performance comparison in long-context scenarios
### Model Convergence
Experiments confirm that siiRL's performance optimizations do not come at the cost of model accuracy. With identical hyperparameters, siiRL's reward and entropy convergence curves match the baseline's exactly, while total training time is **reduced by 21%**.
Figure 7: Model convergence curve comparison
---
## 📚 Resources
Documentation
- Installation Guide
- Quickstart: Running PPO/GRPO
---
## 🗓️ Future Plans
siiRL is under active development. We are excited about the future and are committed to extending the framework's capabilities in two key directions: real-robot VLA training and training-inference separation.
### Embodied VLA Training & Real-World Deployment
We are expanding our Vision-Language-Action (VLA) capabilities to support **real-world robot deployment**.
### Training-Inference Separation Architecture
To improve deployment flexibility and resource utilization, we are developing a **decoupled training-inference architecture**.
---
## 🙏 Acknowledgement
We would first like to thank the open-source RL framework [verl](https://github.com/volcengine/verl), which we used as the primary baseline in our evaluations. We especially acknowledge its hierarchical API design; we reuse verl's `3DParallelWorker` base class to manage system components in siiRL.
siiRL is also built on other excellent open-source projects. We sincerely thank the teams behind PyTorch, Ray, vLLM, vLLM-Ascend, and SGLang for their outstanding work.
Our work addresses the scalability issues we identified during our research and contributes a flexible workflow design; we hope siiRL can contribute positively to the community's collective progress.
---
## 🖋️ Citation
If you find siiRL helpful in your research, please consider citing our paper.
```bibtex
@misc{wang2025distflowfullydistributedrl,
      title={DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training},
      author={Zhixin Wang and Tianyi Zhou and Liming Liu and Ao Li and Jiarui Hu and Dian Yang and Jinlong Hou and Siyuan Feng and Yuan Cheng and Yuan Qi},
      year={2025},
      eprint={2507.13833},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2507.13833},
}
```
================================================
FILE: README.md
================================================
# siiRL: Shanghai Innovation Institute RL Framework for Advanced LLMs and Multi-Agent Systems
| 📄 Paper | 📚 Documentation | Feishu Group | WeChat Group | 🇨🇳 中文 |
**siiRL** is a novel, **fully distributed reinforcement learning (RL) framework** designed to break the scaling barriers in LLM post-training. Developed by researchers from **Shanghai Innovation Institute**, siiRL tackles the critical performance bottlenecks that limit current state-of-the-art systems.
By eliminating the centralized controller common in other frameworks, siiRL delivers **near-linear scalability**, **dramatic throughput gains**, and **unprecedented flexibility** for RL-based LLM development.
---
## 🚀 Highlights
+ **Near-Linear Scalability**: The multi-controller paradigm eliminates central bottlenecks by distributing control logic and data management across all workers, enabling near-linear scalability to thousands of GPUs.
+ **SOTA Throughput**: Fully distributed dataflow architecture minimizes communication and I/O overhead, achieving SOTA throughput in data-intensive scenarios.
+ **Flexible DAG-Defined Pipeline**: Decouple your algorithmic logic from the physical hardware. With siiRL, you can define complex RL workflows as a simple Directed Acyclic Graph (DAG), enabling rapid, cost-effective, and code-free experimentation.
+ **Cross-Hardware Compatibility**: siiRL now officially supports Huawei's Ascend NPUs, providing a high-performance alternative for training and inference on different hardware platforms.
+ **Proven Performance & Stability**: Extensively benchmarked on models from 7B to 72B, siiRL delivers excellent performance across a wide range of tasks. Its advantages are particularly evident in data-intensive workloads such as long-context and multi-modal training.
---
## 📰 News
* **[2025/11]**: siiRL now supports Vision-Language-Action (VLA) model training with [SRPO (Self-Referential Policy Optimization for Vision-Language-Action Models)](https://arxiv.org/pdf/2511.15605), enabling embodied RL training on robotics tasks. See the [documentation](/docs/examples/embodied_srpo_example.rst) for usage instructions.
* **[2025/09]**: Added an explanation of the siiRL [code implementation](/docs/programming_guide/siiRL_code_explained.rst) for interested users and developers. A [Chinese version](https://zhuanlan.zhihu.com/p/1951768778875605883) is also available on Zhihu.
* **[2025/09]**: siiRL now integrates the Megatron training backend with support for MoE training. Performance has been validated on Qwen3-MoE models (30B, 235B).
* **[2025/09]**: siiRL now supports stable scaling on GPU clusters from 32 GPUs up to 1024 GPUs, with over 90% linear scalability efficiency, through collaboration with major manufacturers including Huawei Ascend, MetaX, and Alibaba PPU.
* **[2025/09]**: siiRL supports multi-turn interactions between multiple agents and the environment.
* **[2025/07]**: siiRL adds [MARFT](https://arxiv.org/pdf/2504.16129) support for LaMAS, enabling RL fine-tuning of multi-LLM agents via Flex-POMDP.
* **[2025/07]**: siiRL now supports [CPGD](https://arxiv.org/pdf/2505.12504v1), a novel algorithm that enhances RL training stability and performance by regularizing large policy updates.
* **[2025/07]**: We are excited to release siiRL to the open-source community! Check out our [paper](https://arxiv.org/abs/2507.13833) for a deep dive into the architecture and evaluation.
---
## 💡 Architecture Overview
siiRL is a fully distributed RL framework designed for scalability on large-scale clusters. siiRL employs a multi-controller paradigm that uniformly dispatches all computational and data flow across each GPU. siiRL consists of three main components: a DAG Planner, DAG Workers, and a Data Coordinator.
Figure 1. Overview of siiRL.
siiRL addresses the centralized-controller bottleneck with a **fully distributed, multi-controller architecture**.
Key components include:
* **DAG Planner**: Translates a user-defined logical workflow (DAG) into a serialized, executable pipeline for each worker.
* **DAG Workers**: The core execution units, with each worker bound to a single GPU, running its assigned tasks independently.
* **Data Coordinator**: A set of distributed components (`Distributed Dataloader` and `Distributed Databuffer`) that manage the entire data lifecycle, from initial loading to intermediate data redistribution, without a central coordinator.
### Typical Supported Models & Algorithms
| Models | Algorithms |
| --- | --- |
| **Qwen2.5 series**: Qwen2.5-7B, Qwen2.5-72B, Qwen2.5-VL-7B, Qwen2.5-VL-72B | Reinforcement learning |
| **Qwen3 series**: Qwen3-1.7B, Qwen3-30B, Qwen3-235B-A22B (MoE) | |
| **VLA models** | |
## 🧪 Experiment
We conducted a comprehensive evaluation of siiRL's performance and scalability across various scenarios, comparing it with the SOTA RL framework, verl. The experiments demonstrate that siiRL exhibits outstanding performance across all metrics.
### End-to-End Throughput
Under the standard PPO and GRPO algorithms, siiRL's throughput comprehensively surpasses the baseline. Particularly with the more data-intensive GRPO algorithm, siiRL effectively resolves data bottlenecks through its fully distributed architecture, achieving up to a 2.62x performance improvement.
Figure 2: End-to-end performance comparison using the PPO algorithm
Figure 3: End-to-end performance comparison using the GRPO algorithm
### Large-Scale Scalability
siiRL demonstrates near-linear scalability, smoothly extending up to 1024 GPUs. In contrast, the baseline framework fails under identical conditions due to OOM errors caused by its single-point data bottleneck. At the maximum batch size the baseline can support, siiRL's performance advantage can be as high as 7x.
Figure 4: Near-linear scalability of siiRL on VLM models
Figure 5: VLM task performance comparison under the baseline's maximum load
### Long-Context Performance
When processing long-context tasks, data transfer overhead becomes a major bottleneck. siiRL's distributed dataflow design allows its performance advantage to become more pronounced as context length increases, achieving up to a 2.03x throughput improvement and successfully running a 72B model long-context task that the baseline could not handle.
Figure 6: Performance comparison in long-context scenarios
### Model Convergence
Experiments confirm that siiRL's performance optimizations do not come at the cost of model accuracy. With identical hyperparameters, siiRL's reward and entropy convergence curves are identical to the baseline's, while reducing the total training time by 21%.
Figure 7: Model convergence curve comparison
---
## 📚 Resources
Documentation
- Installation
- Quickstart: Running PPO/GRPO
---
## 🗓️ Future Plans
siiRL is under active development. We are excited about the future and are focused on extending the framework's capabilities in two key directions: training-inference separation and real-robot VLA training.
### Training-Inference Separation Architecture
To enhance deployment flexibility and resource utilization, we are developing a **decoupled training-inference architecture**.
### Embodied VLA Training & Real-World Deployment
We are expanding our Vision-Language-Action (VLA) capabilities to support **real-world robotics deployment**.
We welcome community contributions! Please see our [Contributing Guide](CONTRIBUTING.md) to get started.
---
## 🙏 Acknowledgement
We would first like to thank the open-source RL framework [verl](https://github.com/volcengine/verl), which we used as a primary baseline for our evaluations. We would like to directly acknowledge its hierarchical API design; we reuse the 3DParallelWorker base class from verl to manage system components in siiRL.
siiRL is also built upon a foundation of other great open-source projects. We would like to thank the teams behind PyTorch, Ray, vLLM, vLLM-Ascend and SGLang for their incredible work.
Our work aims to address the scalability challenges identified during our research, and we hope siiRL can contribute positively to the community's collective progress.
---
## 🖋️ Citation
If you find siiRL useful in your research, please consider citing our paper.
```bibtex
@misc{wang2025distflowfullydistributedrl,
title={DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training},
author={Zhixin Wang and Tianyi Zhou and Liming Liu and Ao Li and Jiarui Hu and Dian Yang and Jinlong Hou and Siyuan Feng and Yuan Cheng and Yuan Qi},
year={2025},
eprint={2507.13833},
archivePrefix={arXiv},
primaryClass={cs.DC},
url={https://arxiv.org/abs/2507.13833},
}
```
================================================
FILE: docker/Dockerfile.cu124
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
LABEL maintainer="SII AI Infra Team"
# base environment
RUN apt update \
&& apt install -y rdma-core ibverbs-providers ibverbs-utils \
&& apt install -y python3 python3-pip \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& python -m pip install -U pip \
&& pip install -U setuptools wheel
# dev tools
RUN apt install -y git cmake ninja-build vim
# python packages
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 \
&& pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/ \
&& pip install flash-attn==2.7.3 --no-build-isolation \
&& pip install vllm==0.8.5.post1 \
&& pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils \
&& pip install 'ray[default]>=2.47.1' \
&& pip install opentelemetry-exporter-prometheus==0.47b0 \
&& pip install mbridge \
&& pip install numpy==1.26.4
# apex
RUN git clone https://github.com/NVIDIA/apex.git \
&& cd apex \
&& MAX_JOBS=16 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
&& cd .. && rm -rf apex
# optional: sglang
RUN pip install 'sglang[all]==0.4.6.post5' \
&& pip install xgrammar==0.1.18
# Install TransformerEngine
RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3
# Install Megatron-LM
RUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2
================================================
FILE: docker/Dockerfile.cu126
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
LABEL maintainer="SII AI Infra Team"
# base environment
RUN apt update \
&& apt install -y rdma-core ibverbs-providers ibverbs-utils libnuma-dev \
&& apt install -y python3 python3-pip \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& python -m pip install -U pip \
&& pip install -U setuptools wheel
# dev tools
RUN apt install -y git cmake ninja-build vim
# python packages
RUN pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 \
&& pip install flash-attn==2.8.2 --no-build-isolation \
&& pip install vllm==0.10.0 \
&& pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils \
&& pip install mbridge \
&& pip install numpy==1.26.4
# apex
RUN git clone https://github.com/NVIDIA/apex.git \
&& cd apex \
&& MAX_JOBS=16 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ \
&& cd .. && rm -rf apex
# optional: sglang
RUN pip install 'sglang[all]==0.4.10.post2' \
&& pip install outlines==1.2.3 xgrammar==0.1.21
# Install TransformerEngine
RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3
# Install Megatron-LM
RUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "siiRL"
copyright = "2025, SII AI Infra Team"
author = "SII AI Infra Team"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"myst_parser",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.autosectionlabel",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
]
# Use Google style docstrings instead of NumPy docstrings.
napoleon_google_docstring = True
napoleon_numpy_docstring = False
# Make autosectionlabel use document name as prefix to avoid duplicate label warnings
autosectionlabel_prefix_document = True
# The suffix(es) of source filenames.
# You can specify multiple suffixes by mapping each suffix to a file type:
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
templates_path = ["_templates"]
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "plan_*.md"]
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
================================================
FILE: docs/examples/config.rst
================================================
.. _config-explain-page:
===================
Configuration Guide
===================
siiRL uses Hydra-based configuration management with dataclass parameters. All configuration parameters are defined in the ``siirl/params/`` directory and can be set via command-line arguments.
Configuration Structure
-----------------------
Parameters are organized into the following modules:
- ``DataArguments``: Data-related parameters (``siirl/params/data_args.py``)
- ``ActorRolloutRefArguments``: Actor, Rollout, and Reference model parameters (``siirl/params/model_args.py``)
- ``CriticArguments``: Critic model parameters (``siirl/params/model_args.py``)
- ``RewardModelArguments``: Reward model parameters (``siirl/params/model_args.py``)
- ``AlgorithmArguments``: RL algorithm parameters (``siirl/params/model_args.py``)
- ``TrainingArguments``: Training configuration (``siirl/params/training_args.py``)
- ``DAGArguments``: DAG workflow parameters (``siirl/params/dag_args.py``)
- ``ProfilerArguments``: Profiling parameters (``siirl/params/profiler_args.py``)
All parameters are combined into the ``SiiRLArguments`` class.
Usage
-----
Parameters are set via command-line arguments using dot notation:
.. code-block:: bash
python -m siirl.main_dag \
data.train_files=/path/to/train.parquet \
data.train_batch_size=512 \
actor_rollout_ref.model.path=/path/to/model \
algorithm.adv_estimator=grpo \
trainer.total_epochs=30
Data Parameters
---------------
Location: ``siirl/params/data_args.py``
.. code-block:: bash
data.tokenizer=null
data.train_files=/path/to/train.parquet
data.val_files=/path/to/val.parquet
data.prompt_key=prompt
data.max_prompt_length=512
data.max_response_length=512
data.train_batch_size=1024
data.return_raw_input_ids=False
data.return_raw_chat=False
data.return_full_prompt=False
data.shuffle=True
data.filter_overlong_prompts=False
data.filter_overlong_prompts_workers=1
data.truncation=error
data.image_key=images
data.trust_remote_code=True
**Key Parameters:**
- ``data.train_files``: Training data file path (Parquet format, can be list or single file)
- ``data.val_files``: Validation data file path
- ``data.prompt_key``: Field name for prompt in dataset (default: "prompt")
- ``data.max_prompt_length``: Maximum prompt length (left-padded)
- ``data.max_response_length``: Maximum response length for rollout generation
- ``data.train_batch_size``: Training batch size per iteration
- ``data.return_raw_input_ids``: Return original input_ids without chat template (for different RM chat templates)
- ``data.shuffle``: Whether to shuffle data
- ``data.truncation``: Truncation strategy ("error", "left", "right", "middle")
- ``data.trust_remote_code``: Allow remote code execution for tokenizers
Custom Dataset
~~~~~~~~~~~~~~
.. code-block:: bash
data.custom_cls.path=/path/to/custom_dataset.py
data.custom_cls.name=MyDatasetClass
- ``data.custom_cls.path``: Path to custom dataset class file
- ``data.custom_cls.name``: Name of the dataset class
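For orientation, a custom dataset might look like the following minimal sketch. This is illustrative only: the constructor arguments siiRL actually passes are an assumption here (modeled as the data files plus the prompt key; see ``siirl/data_coordinator/dataloader/`` for the authoritative contract).

.. code-block:: python

   # my_custom_dataset.py -- hypothetical sketch of a custom dataset class.
   import pandas as pd
   from torch.utils.data import Dataset

   class MyDatasetClass(Dataset):
       """Minimal parquet-backed dataset keyed on a prompt column."""

       def __init__(self, data_files, prompt_key="prompt", **kwargs):
           # data_files may be a single parquet path or a list of paths.
           if isinstance(data_files, str):
               data_files = [data_files]
           self.frame = pd.concat(pd.read_parquet(f) for f in data_files)
           self.prompt_key = prompt_key

       def __len__(self):
           return len(self.frame)

       def __getitem__(self, idx):
           row = self.frame.iloc[idx]
           return {"prompt": row[self.prompt_key]}

It would then be selected with ``data.custom_cls.path=/path/to/my_custom_dataset.py`` and ``data.custom_cls.name=MyDatasetClass``.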
Actor/Rollout/Reference Model
------------------------------
Location: ``siirl/params/model_args.py``
Model Configuration
~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
actor_rollout_ref.hybrid_engine=True
actor_rollout_ref.model.path=/path/to/model
actor_rollout_ref.model.external_lib=null
actor_rollout_ref.model.enable_gradient_checkpointing=False
actor_rollout_ref.model.enable_activation_offload=False
actor_rollout_ref.model.trust_remote_code=False
actor_rollout_ref.model.use_remove_padding=False
- ``actor_rollout_ref.model.path``: Huggingface model path (local or HDFS)
- ``actor_rollout_ref.model.external_lib``: Additional Python packages to import
- ``actor_rollout_ref.model.enable_gradient_checkpointing``: Enable gradient checkpointing
- ``actor_rollout_ref.model.enable_activation_offload``: Enable activation offloading
- ``actor_rollout_ref.model.trust_remote_code``: Allow remote code model loading
- ``actor_rollout_ref.model.use_remove_padding``: Remove padding tokens for efficiency
Actor Configuration
~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
actor_rollout_ref.actor.strategy=fsdp
actor_rollout_ref.actor.ppo_mini_batch_size=256
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
actor_rollout_ref.actor.grad_clip=1.0
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0.0
actor_rollout_ref.actor.use_kl_loss=False
actor_rollout_ref.actor.kl_loss_coef=0.001
actor_rollout_ref.actor.ppo_epochs=1
actor_rollout_ref.actor.optim.lr=1e-6
- ``actor.strategy``: Backend strategy ("fsdp" or "megatron")
- ``actor.ppo_mini_batch_size``: Mini-batch size for PPO updates (global across GPUs)
- ``actor.ppo_micro_batch_size_per_gpu``: Micro-batch size per GPU (gradient accumulation)
- ``actor.grad_clip``: Gradient clipping threshold
- ``actor.clip_ratio``: PPO clip ratio
- ``actor.use_kl_loss``: Enable KL loss in actor
- ``actor.kl_loss_coef``: KL loss coefficient (for GRPO)
- ``actor.optim.lr``: Learning rate
Reference Model
~~~~~~~~~~~~~~~
.. code-block:: bash
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16
actor_rollout_ref.ref.fsdp_config.param_offload=False
- ``ref.log_prob_micro_batch_size_per_gpu``: Micro-batch size for reference log prob computation
- ``ref.fsdp_config.param_offload``: Enable parameter offloading (recommended for models > 7B)
Rollout Configuration
~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.temperature=1.0
actor_rollout_ref.rollout.top_k=-1
actor_rollout_ref.rollout.top_p=1.0
actor_rollout_ref.rollout.tensor_model_parallel_size=2
actor_rollout_ref.rollout.gpu_memory_utilization=0.5
actor_rollout_ref.rollout.n=8
- ``rollout.name``: Rollout backend ("vllm", "sglang", "hf")
- ``rollout.temperature``: Sampling temperature
- ``rollout.top_k``: Top-k sampling (-1 for vLLM, 0 for HF)
- ``rollout.top_p``: Top-p sampling
- ``rollout.tensor_model_parallel_size``: Tensor parallelism size (vLLM only)
- ``rollout.gpu_memory_utilization``: GPU memory fraction for vLLM
- ``rollout.n``: Number of responses per prompt (>1 for GRPO/RLOO)
Critic Model
------------
Location: ``siirl/params/model_args.py``
.. code-block:: bash
critic.enable=True
critic.model.path=/path/to/critic_model
critic.ppo_mini_batch_size=256
critic.ppo_micro_batch_size_per_gpu=8
critic.optim.lr=1e-5
Most parameters are similar to Actor configuration.
Reward Model
------------
Location: ``siirl/params/model_args.py``
.. code-block:: bash
reward_model.enable=False
reward_model.model.path=/path/to/reward_model
reward_model.model.input_tokenizer=null
reward_model.micro_batch_size_per_gpu=16
reward_model.reward_manager=naive
- ``reward_model.enable``: Enable reward model (False = use only custom reward functions)
- ``reward_model.model.input_tokenizer``: Input tokenizer path (if different from policy)
- ``reward_model.reward_manager``: Reward manager type ("naive", "batch", "parallel", "dapo", "embodied")
Custom Reward Function
~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
custom_reward_function.path=/path/to/my_reward.py
custom_reward_function.name=compute_score
- ``custom_reward_function.path``: Path to custom reward function file
- ``custom_reward_function.name``: Function name (default: "compute_score")
See :doc:`/user_interface/reward_interface` for details.
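As a sketch of the shape such a function can take, consider the following. The ``(data_source, solution_str, ground_truth, extra_info)`` signature is an assumption modeled on common conventions, not a verified contract; see the bundled ``siirl/user_interface/rewards_interface/custom_gsm8k_reward.py`` and ``examples/custom_reward/rewardfunc_gsm8k.py`` for authoritative examples.

.. code-block:: python

   # my_reward.py -- hypothetical sketch of a custom reward function.
   # The signature below is an assumption for illustration.
   def compute_score(data_source, solution_str, ground_truth, extra_info=None):
       """Return 1.0 if the response's final token matches the label, else 0.0."""
       # Naive parsing: treat the last whitespace-separated token as the answer.
       tokens = solution_str.strip().split()
       answer = tokens[-1] if tokens else ""
       return 1.0 if answer == str(ground_truth).strip() else 0.0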
Algorithm Parameters
--------------------
Location: ``siirl/params/model_args.py``
.. code-block:: bash
algorithm.gamma=1.0
algorithm.lam=1.0
algorithm.adv_estimator=grpo
algorithm.use_kl_in_reward=False
algorithm.kl_penalty=kl
algorithm.kl_ctrl.type=fixed
algorithm.kl_ctrl.kl_coef=0.005
algorithm.workflow_type=DEFAULT
- ``algorithm.gamma``: Discount factor
- ``algorithm.lam``: GAE lambda (bias-variance tradeoff)
- ``algorithm.adv_estimator``: Advantage estimator ("gae", "grpo", "cpgd", "gspo", "rloo")
- ``algorithm.use_kl_in_reward``: Enable KL penalty in reward
- ``algorithm.kl_penalty``: KL divergence calculation method ("kl", "abs", "mse", "low_var_kl", "full")
- ``algorithm.workflow_type``: Workflow type ("DEFAULT", "DAPO", "EMBODIED")
Training Parameters
-------------------
Location: ``siirl/params/training_args.py``
.. code-block:: bash
trainer.total_epochs=30
trainer.project_name=siirl_examples
trainer.experiment_name=gsm8k
trainer.logger=['console', 'wandb']
trainer.nnodes=1
trainer.n_gpus_per_node=8
trainer.save_freq=10
trainer.val_before_train=True
trainer.test_freq=2
- ``trainer.total_epochs``: Number of training epochs
- ``trainer.project_name``: Project name (for logging)
- ``trainer.experiment_name``: Experiment name (for logging)
- ``trainer.logger``: Logger types (["console", "wandb", "tensorboard", "mlflow"])
- ``trainer.nnodes``: Number of nodes
- ``trainer.n_gpus_per_node``: Number of GPUs per node
- ``trainer.save_freq``: Checkpoint saving frequency (by iteration)
- ``trainer.val_before_train``: Run validation before training
- ``trainer.test_freq``: Validation frequency (by iteration)
DAG Parameters
--------------
Location: ``siirl/params/dag_args.py``
.. code-block:: bash
dag.custom_pipeline_fn=null
- ``dag.custom_pipeline_fn``: Custom pipeline function path (e.g., "module:function")
See :doc:`/user_interface/pipeline_interface` for custom pipeline details.
Complete Example
----------------
GRPO Training
~~~~~~~~~~~~~
.. code-block:: bash
python -m siirl.main_dag \
algorithm.adv_estimator=grpo \
algorithm.workflow_type=DEFAULT \
data.train_files=/path/to/gsm8k/train.parquet \
data.train_batch_size=512 \
data.max_prompt_length=2048 \
data.max_response_length=4096 \
actor_rollout_ref.model.path=/path/to/model \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.n=8 \
custom_reward_function.path=siirl/user_interface/rewards_interface/custom_gsm8k_reward.py \
custom_reward_function.name=compute_score \
trainer.total_epochs=30 \
trainer.n_gpus_per_node=8 \
trainer.save_freq=10
PPO Training
~~~~~~~~~~~~
.. code-block:: bash
python -m siirl.main_dag \
algorithm.adv_estimator=gae \
critic.enable=True \
data.train_files=/path/to/data.parquet \
actor_rollout_ref.model.path=/path/to/model \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.rollout.name=vllm \
critic.optim.lr=1e-5 \
trainer.total_epochs=30
DAPO Training
~~~~~~~~~~~~~
.. code-block:: bash
python -m siirl.main_dag \
algorithm.workflow_type=DAPO \
algorithm.adv_estimator=grpo \
algorithm.filter_groups.enable=True \
algorithm.filter_groups.metric=seq_final_reward \
data.train_files=/path/to/data.parquet \
actor_rollout_ref.model.path=/path/to/model \
trainer.total_epochs=30
Parameter Reference
-------------------
For the complete parameter definitions, see:
- ``siirl/params/data_args.py`` - Data parameters
- ``siirl/params/model_args.py`` - Model, algorithm parameters
- ``siirl/params/training_args.py`` - Training parameters
- ``siirl/params/dag_args.py`` - DAG workflow parameters
- ``siirl/params/profiler_args.py`` - Profiler parameters
================================================
FILE: docs/examples/cpgd_example.rst
================================================
DeepScaleR Example with CPGD
==============================
Introduction
------------
This example demonstrates how to fine-tune a Large Language Model for advanced mathematical reasoning on the **DeepScaleR** dataset using **Clipped Policy Gradient Optimization with Policy Drift (CPGD)**, a novel reinforcement learning algorithm designed for enhanced training stability.
**Paper:** CPGD: Toward Stable Rule-based Reinforcement Learning for Language Models
**Dataset:** https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset
While algorithms like PPO and GRPO are powerful, they can sometimes suffer from instability due to their reliance on importance-sampling ratios in the loss function. CPGD is proposed to mitigate these issues by providing a more stable policy update mechanism, making it a robust choice for complex reasoning tasks.
CPGD Algorithm Overview
-----------------------
CPGD enhances training stability by making two key modifications to the standard policy gradient approach:
1. **Clipped Policy Gradient Objective**: Instead of directly using the policy ratio in the loss (which can cause high variance), CPGD uses a policy gradient objective. It then applies a clipping mechanism to the *logarithm* of the policy ratio. This prevents excessive policy updates when the ratio becomes too large, effectively keeping the optimization within a trusted region.
2. **Policy Drift Regularization**: CPGD introduces a *policy drift* term, which is a KL divergence penalty between the current policy and the old policy from the start of the training iteration. This acts as a dynamic regularizer, pulling the policy back if it strays too far, too quickly, thus preventing training collapse.
Together, these features allow CPGD to achieve consistent performance improvements while avoiding the instability often seen in other RL algorithms.
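To make these two ingredients concrete, here is an illustrative sketch of a clipped log-ratio policy-gradient loss with a drift penalty. The names ``clip_eps`` and ``drift_coeff`` are assumptions; this is a paraphrase of the idea, not the paper's or siiRL's exact objective.
.. code:: python

    import torch


    def cpgd_style_loss(logprob, old_logprob, advantages, clip_eps=0.2, drift_coeff=0.001):
        """Illustrative sketch: clipped log-ratio PG term plus a policy-drift penalty."""
        log_ratio = logprob - old_logprob
        # (1) Clip the *log* of the policy ratio, so tokens whose policy has
        #     already moved far receive no further gradient push.
        pg_loss = -(advantages * torch.clamp(log_ratio, -clip_eps, clip_eps)).mean()
        # (2) Policy drift: a KL-style penalty pulling the policy back toward
        #     the policy from the start of the training iteration.
        drift = ((-log_ratio).exp() + log_ratio - 1.0).mean()
        return pg_loss + drift_coeff * drift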
Step 1: Prepare the Dataset
---------------------------
The data preparation process is identical to other examples using this dataset. First, preprocess the DeepScaleR dataset into the required Parquet format.
.. code:: bash
cd examples/data_preprocess
python3 deepscaler.py --local_dir ~/data/deepscaler
This command downloads, processes, and saves the training and testing sets in the `~/data/deepscaler` directory.
Step 2: Download the Pre-trained Model
--------------------------------------
You need a base model to start the CPGD training. In this example, we use `Qwen2.5-7B-Instruct`.
- **Recommended: Download via CLI:** Use a tool like `huggingface-cli` to download the model to a local directory.
.. code:: bash
huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir ~/data/models/Qwen2.5-7B-Instruct
- **Automatic Download:** You can also specify the model name directly in the `actor_rollout_ref.model.path` field of the run script, and the framework will download it automatically.
Step 3: Perform CPGD Training
-----------------------------
With the data and model ready, you can now launch the training job using the CPGD algorithm.
**Reward Function**
For this task, we use the same rule-based reward function as in the PPO/GRPO examples. The framework's default reward mechanism performs an exact match on the final answer within the `\\boxed{...}` block. A correct answer receives a positive reward, and an incorrect one receives zero.
**Training Script**
Below is a complete training script from `examples/cpgd_trainer/run_qwen2_5-7b.sh`. It is configured to use the CPGD algorithm (`algorithm.adv_estimator=cpgd`). Note the presence of CPGD-specific parameters like `actor_rollout_ref.actor.policy_drift_coeff` and `algorithm.weight_factor_in_cpgd`.
.. literalinclude:: ../../examples/cpgd_trainer/run_qwen2_5-7b.sh
:language: bash
:caption: examples/cpgd_trainer/run_qwen2_5-7b.sh
================================================
FILE: docs/examples/deepscaler_example.rst
================================================
DeepScaleR Example with PPO
=============================
Introduction
------------
This example demonstrates how to fine-tune a Large Language Model for advanced mathematical reasoning using the **DeepScaleR** dataset.
**Paper:** https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2.
**Dataset:** https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset
The core idea is to leverage Reinforcement Learning (RL), specifically Proximal Policy Optimization (PPO), to teach the model not just to find the correct answer, but to follow a logical, step-by-step reasoning process. This is achieved by rewarding the model based on the correctness of its final answer, which is extracted from a structured output.
Dataset Overview
----------------
The DeepScaleR dataset consists of challenging mathematical problems. Each sample includes a question (`problem`), a detailed reasoning path (`solution`), and a final answer enclosed in a `\\boxed{}` block (`answer`).
**An example from DeepScaleR:**
**Prompt:**
"Let $a_n=6^{n}+8^{n}$. Determine the remainder upon dividing $a_ {83}$ by $49$."
**Solution:**
"$6^{83} + 8^{83} = (6+8)(6^{82}-6^{81}8+\\ldots-8^{81}6+8^{82})$\n Becuase $7|(6+8)$, we only consider $6^{82}-6^{81}8+\\ldots-8^{81}6+8^{82} \\pmod{7}$\n$6^{82}-6^{81}8+\\ldots-8^{81}6+8^{82} \\equiv (-1)^{82} - (-1)^{81}+ \\ldots - (-1)^1 + 1 = 83 \\equiv 6 \\pmod{7}$\n$6^{83} + 8^{83} \\equiv 14 \\cdot 6 \\equiv \\boxed{035} \\pmod{49}$"
**Answer:**
`35`
Step 1: Prepare the Dataset
---------------------------
First, preprocess the DeepScaleR dataset into the required Parquet format. Our framework includes a script for this purpose.
.. code:: bash
cd examples/data_preprocess
python3 deepscaler.py --local_dir ~/data/deepscaler
This will download the dataset from Hugging Face, process it, and save `train.parquet` and `test.parquet` files in the `~/data/deepscaler` directory.
Step 2: Download the Pre-trained Model
--------------------------------------
You need a base model to start the PPO training. In this example, we use `Qwen2.5-7B-Instruct`. There are several ways to make the model available to the trainer:
- **Recommended: Download via CLI:** Use tools like `huggingface-cli` or `modelscope` to download the model to a local directory. This gives you more control.
.. code:: bash
# For Hugging Face
huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir ~/data/models/Qwen2.5-7B-Instruct --local-dir-use-symlinks False
# For ModelScope
modelscope download Qwen/Qwen2.5-7B-Instruct --local_dir ~/data/models/Qwen2.5-7B-Instruct
- **Automatic Download:** You can also specify the Hugging Face model name (e.g., `Qwen/Qwen2.5-7B-Instruct`) directly in the `actor_rollout_ref.model.path` and `critic.model.path` fields of your run script. The framework will attempt to download it automatically on the first run.
Step 3: Perform PPO Training
----------------------------
With the data and model ready, you can now launch the PPO training job.
**Reward Function**
For this task, we use a simple but effective rule-based reward function. The framework's default reward mechanism will be used, which performs an exact match between the model's generated answer and the `ground_truth` from the dataset.
- The model is prompted to provide its final answer inside a `\\boxed{...}` block.
- The reward function checks if the content inside the generated `\\boxed{}` matches the ground truth answer.
- A correct match receives a positive reward (e.g., 1.0), while an incorrect match or a malformed response receives zero reward (see the sketch below).
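For illustration, the extraction-and-compare step could look like the sketch below. The numeric normalization (so that ``035`` from the sample above matches ``35``) is an assumption for clarity, not necessarily the framework's exact rule.
.. code:: python

    import re


    def boxed_answer(text: str):
        """Return the content of the last \\boxed{...} block, or None."""
        matches = re.findall(r"\\boxed\{([^}]*)\}", text)
        return matches[-1] if matches else None


    def is_correct(response: str, ground_truth: str) -> bool:
        answer = boxed_answer(response)
        if answer is None:
            return False  # malformed response -> zero reward
        try:
            return float(answer) == float(ground_truth)  # "035" matches "35"
        except ValueError:
            return answer.strip() == ground_truth.strip()


    assert is_correct(r"... \boxed{035} \pmod{49}", "35")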
**Training Script**
Below is a complete training script based on `examples/ppo_trainer/run_qwen3-8b.sh`. It is configured for a single-node, multi-GPU setup. You should adapt paths like `HOME` to your environment.
.. literalinclude:: ../../examples/ppo_trainer/run_qwen3-8b.sh
:language: bash
:caption: examples/ppo_trainer/run_qwen3-8b.sh
================================================
FILE: docs/examples/embodied_srpo_example.rst
================================================
Embodied SRPO Training
======================
Introduction
------------
This guide explains how to perform Embodied AI training using the SRPO algorithm with OpenVLA-oft models on tasks such as LIBERO. Embodied AI training involves an agent interacting with an environment, where the rewards are often based on task success.
This example demonstrates how to perform RL training on an `OpenVLA-oft-7B` model using the SRPO algorithm on the `libero_long` benchmark.
Step 1: Prepare the Environment
-------------------------------
You should use the provided Docker image for Embodied AI training, which contains all necessary dependencies including EGL support for rendering.
**Docker Image**: ``siiai/siirl-vla:libero-egl-cu12.6`` (available on Docker Hub)
Ensure you have the necessary environment variables set. This includes the path to the `siiRL` repository and any other dependencies.
.. code:: bash
export SIIRL_DIR="/path/to/siiRL"
export VJEPA2_DIR="$HOME/code/vjepa2" # V-JEPA 2 code repository (https://github.com/facebookresearch/vjepa2)
export PYTHONPATH="$SIIRL_DIR:/path/to/LIBERO:$VJEPA2_DIR:$PYTHONPATH"
Step 2: Prepare the Models
--------------------------
You need the following models:
1. **SFT Model**: A Supervised Fine-Tuned (SFT) OpenVLA-oft model. You should select the model that corresponds to your specific task. For example, if you are training on `libero_long`, you should use the `Sylvest/OpenVLA-AC-PD-1traj-libero-long` model.
Here are the recommended Hugging Face models from the Sylvest collection:
- `Sylvest/OpenVLA-AC-PD-1traj-libero-object` (for `libero_object`)
- `Sylvest/OpenVLA-AC-PD-1traj-libero-spatial` (for `libero_spatial`)
- `Sylvest/OpenVLA-AC-PD-1traj-libero-goal` (for `libero_goal`)
- `Sylvest/OpenVLA-AC-PD-1traj-libero-long` (for `libero_long`)
2. **Visual Encoder**: The V-JEPA 2 visual encoder model is **required** for processing visual observations.
- First, clone the V-JEPA 2 code repository (``facebookresearch/vjepa2``) from GitHub:
.. code:: bash
git clone https://github.com/facebookresearch/vjepa2.git $HOME/code/vjepa2
Make sure to add the V-JEPA 2 directory to your ``PYTHONPATH`` as shown in Step 1.
- Then, download the V-JEPA 2 model weights from Hugging Face (``Sylvest/vjepa2-vit-g``):
.. code:: bash
huggingface-cli download Sylvest/vjepa2-vit-g --local-dir $HOME/models/vjepa2
Set the paths to these resources in your environment or script:
.. code:: bash
export MODEL_PATH=$HOME/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long
export VJEPA_MODEL_PATH=$HOME/models/vjepa2/vitg-384.pt
.. note::
You do not need to manually prepare a dataset file. ``siiRL`` will automatically generate the task manifest (Parquet files) based on the environment configuration and save them to the path specified in ``TRAIN_DATA_PATH`` and ``TEST_DATA_PATH``.
Step 3: Configure and Run the Training Script
---------------------------------------------
Embodied AI training requires specific configurations to handle the environment interaction and action spaces.
Key Configuration Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**Embodied Specifics**:
- ``actor_rollout_ref.embodied.embodied_type``: The model type (e.g., ``openvla-oft``).
- ``actor_rollout_ref.embodied.action_token_len``: The dimensionality of the action space (e.g., 7 for xyz + quaternion + gripper).
- ``actor_rollout_ref.embodied.action_chunks_len``: The number of action steps predicted in one forward pass.
- ``actor_rollout_ref.embodied.video_embedding_model_path``: Path to the V-JEPA 2 video embedding model (e.g., ``$VJEPA_MODEL_PATH``).
**Environment Configuration**:
- ``actor_rollout_ref.embodied.env.env_type``: The environment library (e.g., ``libero``).
- ``actor_rollout_ref.embodied.env.env_name``: The specific task suite name (e.g., ``libero_long``).
- ``actor_rollout_ref.embodied.env.num_envs``: Number of parallel environments per rollout worker. Default is 16 environments per GPU, and it is not recommended to exceed 16.
- ``actor_rollout_ref.embodied.env.max_steps``: Maximum steps per episode.
**Algorithm Adjustments**:
- ``algorithm.embodied_sampling.filter_accuracy``: Enable filtering of prompts based on estimated success rate (see the sketch after this list).
- ``algorithm.embodied_sampling.accuracy_lower_bound``: Lower threshold for filtering (e.g., 0.1).
- ``algorithm.embodied_sampling.accuracy_upper_bound``: Upper threshold for filtering (e.g., 0.9).
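The filtering itself amounts to keeping only prompts whose estimated success rate falls inside the band; prompts that are always or never solved carry no useful learning signal under group-relative advantages. A minimal sketch, assuming success is recorded as 0/1 per rollout:
.. code:: python

    def keep_prompt(successes, lower=0.1, upper=0.9):
        """Keep a prompt only if its estimated success rate is inside [lower, upper]."""
        rate = sum(successes) / len(successes)
        return lower <= rate <= upper


    # e.g. 8 rollouts of one prompt with 6 successes -> rate 0.75 -> kept
    assert keep_prompt([1, 1, 1, 0, 1, 1, 0, 1])
    assert not keep_prompt([1] * 8)  # always solved -> filtered out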
Complete Training Script
~~~~~~~~~~~~~~~~~~~~~~~~
Below is an example script `run_embodied_srpo.sh` to run SRPO training on `libero_long`.
**Note**: The siiRL repository provides ready-to-use training scripts for all four LIBERO tasks in the `examples/embodied_srpo_trainer/` directory:
- ``run_openvla_oft_libero_long.sh``
- ``run_openvla_oft_libero_goal.sh``
- ``run_openvla_oft_libero_object.sh``
- ``run_openvla_oft_libero_spatial.sh``
To train on a specific task, modify the following paths in the script to match your actual environment:
- ``SIIRL_DIR``: Path to the siiRL repository
- ``VJEPA2_DIR``: Path to the V-JEPA2 repository (for ``PYTHONPATH``)
- ``HOME_PATH``: Your home directory or base path for models and data
- ``MODEL_PATH``: Path to the corresponding SFT model for the task
- ``VJEPA_MODEL_PATH``: Path to the V-JEPA 2 model weights file
**Note**: LIBERO is pre-installed in the Docker image at ``/root/LIBERO/`` and does not need to be modified.
.. code-block:: bash
#!/usr/bin/env bash
# ===================================================================================
# === Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-LONG ===
# ===================================================================================
#
set -e
# --- Environment Setup (Critical for siiRL) ---
export SIIRL_DIR="${SIIRL_DIR:-your_siirl_path}"
export PYTHONPATH="$SIIRL_DIR:/root/LIBERO/:${VJEPA2_DIR:-your_vjepa2_path}:$PYTHONPATH"
# --- Experiment and Model Definition ---
export DATASET=libero_long
export ALG=srpo
export MODEL_NAME=openvla-oft-7b
export MODEL_TYPE=openvla-oft
# --- Path Definitions (USER PROVIDED) ---
export HOME_PATH=${HOME_PATH:-your_home_path}
export TRAIN_DATA_PATH=$HOME_PATH/data/train.parquet # generated automatically
export TEST_DATA_PATH=$HOME_PATH/data/test.parquet # generated automatically
export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long
export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Embodied AI Specific Parameters ---
export ACTION_TOKEN_LEN=7 # 7 dimensions: xyz (3), quaternion (3), gripper (1)
export ACTION_CHUNKS_LEN=8 # OpenVLA-OFT uses 8-step action chunks
export NUM_ENVS=16 # actor_rollout_ref.embodied.env.num_envs
export MAX_EPISODE_STEPS=512 # actor_rollout_ref.embodied.env.max_steps
# --- Data and Sampling Parameters ---
export VAL_BATCH_SIZE=496 # Validation batch size
export MAX_PROMPT_LENGTH=256
export MAX_RESPONSE_LENGTH=128
# --- Embodied Sampling Parameters ---
export FILTER_ACCURACY=True # Enable accuracy-based filtering
export ACCURACY_LOWER_BOUND=0.1 # Only keep prompts with success rate >= 0.1
export ACCURACY_UPPER_BOUND=0.9 # Only keep prompts with success rate <= 0.9
export FILTER_TRUNCATED=False # Filter truncated episodes (uses env.max_steps)
export OVERSAMPLE_FACTOR=1 # Oversample factor for filtering
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE=64 # data.train_batch_size
export PPO_MINI_BATCH_SIZE=4 # actor_rollout_ref.actor.ppo_mini_batch_size
# Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES
export ROLLOUT_N_SAMPLES=8 # REUSED: Number of samples per prompt
export PPO_EPOCHS=1 # actor_rollout_ref.actor.ppo_epochs
# Algorithm parameters
export LEARNING_RATE=5e-6
export WEIGHT_DECAY=0.0 # actor_rollout_ref.actor.optim.weight_decay
export CLIP_RATIO_HIGH=0.28 # actor_rollout_ref.actor.clip_ratio_high
export CLIP_RATIO_LOW=0.2 # actor_rollout_ref.actor.clip_ratio_low
export ENTROPY_COEFF=0.0
export TEMPERATURE=1.6
export GAMMA=1.0
export LAM=1.0
export GRAD_CLIP=1.0
# --- Image/Video Processing ---
export IMG_SIZE=384 # actor_rollout_ref.embodied.img_size
export ENABLE_FP16=True # actor_rollout_ref.embodied.enable_fp16
export EMBEDDING_MODEL_OFFLOAD=False # actor_rollout_ref.embodied.embedding_model_offload
export CENTER_CROP=True # actor_rollout_ref.embodied.center_crop
export NUM_IMAGES_IN_INPUT=1
export NUM_STEPS_WAIT=10 # Environment stabilization steps
# --- Trainer Configuration ---
export SAVE_FREQ=5
export TEST_FREQ=5
export TOTAL_EPOCHS=1000 # trainer.total_epochs
export MAX_CKPT_KEEP=5 # trainer.max_actor_ckpt_to_keep
export VAL_BEFORE_TRAIN=True # trainer.val_before_train
# --- Multi-node distributed training ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}
# --- Environment Variables ---
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export GLOO_SOCKET_TIMEOUT=600
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_embodied_${DATASET}
export EXPERIMENT_NAME=openvla_oft_srpo_fsdp
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}
# --- Define the Training Command ---
TRAINING_CMD=(
python3 -m siirl.client.main_dag
--config-name=embodied_srpo_trainer
# Data configuration
data.train_files=$TRAIN_DATA_PATH
data.val_files=$TEST_DATA_PATH
data.train_batch_size=$TRAIN_BATCH_SIZE
data.val_batch_size=$VAL_BATCH_SIZE
data.max_prompt_length=$MAX_PROMPT_LENGTH
data.max_response_length=$MAX_RESPONSE_LENGTH
# Algorithm configuration
algorithm.workflow_type=embodied
algorithm.adv_estimator=grpo
algorithm.gamma=$GAMMA
algorithm.lam=$LAM
algorithm.norm_adv_by_std_in_grpo=True
# Embodied sampling configuration (aligned with DAPO architecture)
algorithm.embodied_sampling.filter_accuracy=$FILTER_ACCURACY
algorithm.embodied_sampling.accuracy_lower_bound=$ACCURACY_LOWER_BOUND
algorithm.embodied_sampling.accuracy_upper_bound=$ACCURACY_UPPER_BOUND
algorithm.embodied_sampling.filter_truncated=$FILTER_TRUNCATED
algorithm.embodied_sampling.oversample_factor=$OVERSAMPLE_FACTOR
# Model configuration
actor_rollout_ref.model.path=$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
# Actor configuration
actor_rollout_ref.actor.optim.lr=$LEARNING_RATE
actor_rollout_ref.actor.optim.weight_decay=$WEIGHT_DECAY
actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_epochs=$PPO_EPOCHS
actor_rollout_ref.actor.grad_clip=$GRAD_CLIP
actor_rollout_ref.actor.clip_ratio_high=$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_low=$CLIP_RATIO_LOW
actor_rollout_ref.actor.entropy_coeff=$ENTROPY_COEFF
actor_rollout_ref.actor.shuffle=True
# Actor FSDP configuration
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.grad_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
# Rollout configuration
actor_rollout_ref.rollout.name=hf
actor_rollout_ref.rollout.n=$ROLLOUT_N_SAMPLES
actor_rollout_ref.rollout.temperature=$TEMPERATURE
actor_rollout_ref.rollout.do_sample=True
actor_rollout_ref.rollout.response_length=512
# Embodied AI specific configuration
actor_rollout_ref.embodied.embodied_type=$MODEL_TYPE
actor_rollout_ref.embodied.action_token_len=$ACTION_TOKEN_LEN
actor_rollout_ref.embodied.action_chunks_len=$ACTION_CHUNKS_LEN
actor_rollout_ref.embodied.video_embedding_model_path=$VJEPA_MODEL_PATH
actor_rollout_ref.embodied.embedding_img_size=$IMG_SIZE
actor_rollout_ref.embodied.embedding_enable_fp16=$ENABLE_FP16
actor_rollout_ref.embodied.embedding_model_offload=$EMBEDDING_MODEL_OFFLOAD
actor_rollout_ref.embodied.center_crop=$CENTER_CROP
actor_rollout_ref.embodied.num_images_in_input=$NUM_IMAGES_IN_INPUT
actor_rollout_ref.embodied.unnorm_key=$DATASET
# Environment configuration
actor_rollout_ref.embodied.env.env_type=libero
actor_rollout_ref.embodied.env.env_name=$DATASET
actor_rollout_ref.embodied.env.num_envs=$NUM_ENVS
actor_rollout_ref.embodied.env.max_steps=$MAX_EPISODE_STEPS
actor_rollout_ref.embodied.env.num_steps_wait=$NUM_STEPS_WAIT
actor_rollout_ref.embodied.env.num_trials_per_task=50
actor_rollout_ref.embodied.env.model_family=openvla
# Critic configuration (SRPO doesn't use critic)
critic.use_critic_model=False
# Trainer configuration
trainer.total_epochs=$TOTAL_EPOCHS
trainer.save_freq=$SAVE_FREQ
trainer.test_freq=$TEST_FREQ
trainer.max_actor_ckpt_to_keep=$MAX_CKPT_KEEP
trainer.logger=['console','tensorboard']
trainer.project_name=$PROJECT_NAME
trainer.experiment_name=$EXPERIMENT_NAME
trainer.nnodes=$NNODES
trainer.n_gpus_per_node=$N_GPUS_PER_NODE
trainer.default_local_dir=$CKPT_PATH
trainer.resume_mode=auto
trainer.val_before_train=$VAL_BEFORE_TRAIN
)
# ===================================================================================
# === EXECUTION LOGIC ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
Step 4: Checking the Results
----------------------------
1. **Logs**: Monitor the console output for training progress and environment interaction stats.
2. **TensorBoard**: Use TensorBoard to visualize rewards, success rates, and other metrics.
.. code:: bash
tensorboard --logdir ./tensorboard
3. **Checkpoints**: Trained models are saved in the ``ckpts`` directory.
================================================
FILE: docs/examples/megatron_backend_example.rst
================================================
Megatron-LM Training Backend
============================================
Introduction
------------
This guide explains how to use the Megatron-LM backend in siiRL for RL training. Megatron-LM is a powerful library for training very large transformer models, and integrating it as a backend allows for efficient 5D parallelism (DP/TP/EP/PP/CP).
This example demonstrates how to fine-tune a `Qwen3-8B` model using the GRPO algorithm with Megatron-LM as the training backend.
Step 1: Prepare the Dataset
---------------------------
First, ensure your dataset is in the required Parquet format. If you are using one of the example datasets like `gsm8k` or `deepscaler`, you can use the provided preprocessing scripts. For example, for `deepscaler`:
.. code:: bash
cd examples/data_preprocess
python3 deepscaler.py --local_dir ~/data/deepscaler
This will download and process the dataset, saving `train.parquet` and `test.parquet` in the specified directory.
Step 2: Download the Pre-trained Model
--------------------------------------
You need a base model to start training. For this example, we'll use `Qwen3-8B`. Download it from Hugging Face or ModelScope to a local directory.
.. code:: bash
# For Hugging Face
huggingface-cli download Qwen/Qwen3-8B --local-dir ~/data/models/Qwen3-8B --local-dir-use-symlinks False
# For ModelScope
modelscope download Qwen/Qwen3-8B --local_dir ~/data/models/Qwen3-8B
Step 3: Configure and Run the Training Script
---------------------------------------------
To use the Megatron-LM backend, you need to modify the training configuration in your run script.
Key Configuration Changes
~~~~~~~~~~~~~~~~~~~~~~~~~
The main change is to set the training strategy to `megatron` and configure its parallelism parameters.
1. **Set the Strategy**: e.g., in the `TRAINING_CMD` array, set `actor_rollout_ref.actor.strategy=megatron`.
2. **Configure Parallelism**: Add Megatron-specific settings for 5D parallelism. For an 8B model on a single node with 8 GPUs, you might use 2-way tensor parallelism and 4-way pipeline parallelism, with sequence parallelism enabled (TP × PP × CP = 2 × 4 × 1 = 8, leaving one data-parallel replica).
.. code-block:: text
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=4
actor_rollout_ref.actor.megatron.context_parallel_size=1
actor_rollout_ref.actor.megatron.sequence_parallel=True
3. **Configure Distributed Optimizer**: Add Megatron-specific settings for distributed optimizer. This allows for memory efficient training with ZeRO-1 optimization and is recommended for large models.
.. code-block:: text
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
4. **Configure Offloading**: Add Megatron-specific settings for parameter, gradient, and optimizer offload. This allows for parameter, gradient, and optimizer offloading to CPU to save GPU memory.
.. code-block:: text
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.grad_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
Complete Training Script
~~~~~~~~~~~~~~~~~~~~~~~~
Below is a complete example script, `run_qwen3-8b-megatron.sh`, which is adapted from the standard GRPO script to use the Megatron backend. You will need to create this script yourself or adapt an existing one.
.. code-block:: bash
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For debugging
export HYDRA_FULL_ERROR=1
export SIIRL_LOG_VERBOSITY=INFO
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export HOME=${HOME:-"/root"} # Set your home path
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=$HOME/ckpts
export BASE_TENSORBOARD_PATH=$HOME/tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=128
export PPO_MINI_BATCH_SIZE_PER_NODE=16
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.45
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# ---- Megatron Parallelism Configuration ----
export ACTOR_REF_TP=2
export ACTOR_REF_PP=4
export ACTOR_REF_CP=1
export ACTOR_REF_SP=True
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_megatron_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_megatron_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_megatron_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_megatron_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
# --- Megatron Backend Configuration ---
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=False
# --- PPO & Other Hyperparameters ---
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.grad_clip=1.0
# --- Rollout (vLLM) Configuration ---
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
# --- Trainer Configuration ---
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
Step 4: Checking the Results
----------------------------
During training, you can monitor the progress through several means:
1. **Console Logs**: The console will output detailed logs. Look for initialization messages from the Megatron backend to confirm it's being used. You should see logs pertaining to the setup of 5D parallelism.
2. **TensorBoard**: If you enabled the `tensorboard` logger, you can monitor training metrics in real-time.
.. code:: bash
tensorboard --logdir $HOME/tensorboard
Navigate to the TensorBoard URL in your browser to view metrics such as reward, KL divergence, and loss curves.
3. **Checkpoints**: Checkpoints will be saved in the directory specified by `CKPT_PATH`. You can use these to resume training or for inference later.
================================================
FILE: docs/examples/mm_eureka_example.rst
================================================
MM-Eureka Example with GRPO
===========================
Introduction
------------
This guide details how to fine-tune a multi-modal Large Language Model using the **Group Relative Policy Optimization (GRPO)** algorithm on the **MM-Eureka** dataset. MM-Eureka is a challenging dataset designed to test mathematical reasoning that requires interpreting both text and images.
**Paper:** https://arxiv.org/pdf/2503.07365.
**Dataset:** https://huggingface.co/datasets/FanqingM/MM-Eureka-Dataset
The goal is to enhance a model's ability to perform complex reasoning by processing visual and textual information simultaneously. We use GRPO, an advanced RL algorithm, to optimize the model's policy.
Dataset Overview
----------------
MM-Eureka problems consist of a text-based question paired with one or more images. The model must understand the content of the image to solve the problem correctly.
**An example from MM-Eureka:**
**Prompt:**
.. image:: https://github.com/sii-research/siiRL/raw/main/docs/_static/cube.jpg
:width: 50%
Question: A cube loses one vertex after a 'corner' is removed. This geometric shape is ___ (fill in the number).
**Answer:**
3
Step 1: Data Preprocessing
--------------------------
The raw MM-Eureka dataset, typically in `.jsonl` format, must be converted to Parquet. This involves not only structuring the text but also processing the associated images.
The script `examples/data_preprocess/mm_eureka.py` handles this. It performs the following actions:
- Parses each line of the input JSONL file.
- Reads the image file specified in `image_urls` and embeds its byte content directly into the Parquet file.
- Formats the user prompts to include instructions for the required structured output format.
- Splits the data into training and testing sets.
Run the script with your dataset file:
.. code:: bash
cd examples/data_preprocess
python3 mm_eureka.py --jsonl_file /path/to/your/mm_eureka_data.jsonl --output_dir ~/data/mm_eureka/
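Conceptually, each output row pairs the formatted prompt with the raw image bytes. A simplified sketch of that conversion is shown below; the field names are assumptions based on the description above, not the script's exact schema, and it assumes ``pandas`` with a Parquet engine such as ``pyarrow`` is installed.
.. code:: python

    import json

    import pandas as pd


    def convert(jsonl_file: str, parquet_file: str) -> None:
        rows = []
        with open(jsonl_file) as f:
            for line in f:
                sample = json.loads(line)
                # Embed each image's raw bytes directly into the row.
                images = []
                for path in sample["image_urls"]:  # local file paths
                    with open(path, "rb") as img:
                        images.append({"bytes": img.read()})
                rows.append({
                    "prompt": sample["question"],  # assumed field names
                    "images": images,
                    "answer": sample["answer"],
                })
        pd.DataFrame(rows).to_parquet(parquet_file)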
Step 2: Defining the Reward Score
---------------------------------
A custom reward function is crucial for multi-modal reasoning. For MM-Eureka, we use a composite score defined in `siirl/utils/reward_score/mm_eureka.py`. This function evaluates two aspects of the model's response:
1. **Accuracy Reward**: This is the primary component. It parses the mathematical expression from the model's output (often in LaTeX) and compares it against the ground truth using the `math_verify` utility. This provides a robust check for mathematical correctness.
2. **Format Reward**: A smaller, secondary reward is given if the model correctly follows the required response structure. This encourages the model to generate well-formed, interpretable reasoning chains.
The final reward is a weighted sum of these two components (e.g., `0.9 * accuracy_reward + 0.1 * format_reward`), balancing correctness with style.
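A hedged sketch of how such a composite score could be assembled with the open-source ``math_verify`` package is shown below; the format predicate is a stand-in, since the exact required tag structure is defined in the preprocessing script.
.. code:: python

    from math_verify import parse, verify


    def looks_well_formed(response: str) -> bool:
        # Stand-in predicate: the real check validates the required
        # reasoning/answer structure from the preprocessing step.
        return bool(response.strip())


    def compute_score(response: str, ground_truth: str) -> float:
        accuracy = 1.0 if verify(parse(ground_truth), parse(response)) else 0.0
        fmt = 1.0 if looks_well_formed(response) else 0.0
        return 0.9 * accuracy + 0.1 * fmt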
Step 3: Download the Pre-trained Model
--------------------------------------
For this multi-modal task, we use a powerful vision-language model like `Qwen2.5-VL-7B-Instruct`. Ensure the model is available locally for the training script.
- **Recommended: Download via CLI:**
.. code:: bash
# For Hugging Face
huggingface-cli download Qwen/Qwen2.5-VL-7B-Instruct --local-dir ~/data/models/Qwen2.5-VL-7B-Instruct
# For ModelScope
modelscope download Qwen/Qwen2.5-VL-7B-Instruct --local_dir ~/data/models/Qwen2.5-VL-7B-Instruct
- **Automatic Download:** Alternatively, specify the model identifier directly in the run script's `actor_rollout_ref.model.path` field.
Step 4: Perform GRPO Training
-----------------------------
With the data and model prepared, you can launch the training job using the GRPO algorithm.
**Training Script**
The script `examples/grpo_trainer/run_qwen2_5_vl-7b.sh` provides a complete configuration for this task. It sets up the environment, Ray cluster, and all necessary hyperparameters for GRPO training on the MM-Eureka dataset. Adapt the `HOME` path and other variables as needed for your environment.
.. literalinclude:: ../../examples/grpo_trainer/run_qwen2_5_vl-7b.sh
:language: bash
:caption: examples/grpo_trainer/run_qwen2_5_vl-7b.sh
================================================
FILE: docs/hardware_tutorial/ascend_profiling_en.rst
================================================
Data Collection on Ascend Devices Based on the FSDP Backend
============================================================
Last updated: 08/14/2025.
This tutorial describes how to collect profiling data during GRPO training on Ascend devices with the FSDP backend.
Configuration
-------------
- Global Collection Control: use the ``profiler`` configuration items in siirl/client/config/ppo_trainer.yaml to control the default collection mode:
- enable: Whether to enable performance profiling.
- save_path: The path to save collected data.
- level: Collection level; options include level_none, level0, level1, and level2.
- level_none: Disables all level-based data collection (turns off profiler_level).
- level0: Collects high-level application data, low-level NPU data, and operator execution details on the NPU.
- level1: Adds CANN layer AscendCL data and AI Core performance metrics on the NPU based on level0.
- level2: Adds CANN layer Runtime data and AI CPU metrics based on level1.
- with_memory: Enables memory analysis (defaults to True).
- record_shapes: Enables recording of tensor shapes (defaults to False).
- with_npu: Enables collection of device-side performance data (defaults to True).
- with_cpu: Enables collection of host-side performance data (defaults to True).
- with_module: Enables recording of framework-level Python call stack information.
- with_stack: Enables recording of operator call stack information.
- analysis: Enables automatic data analysis.
- discrete: Enables discrete mode, collecting performance data for each stage separately (defaults to False).
- roles: Collection stage - used in conjunction with the discrete parameter. Options include:
generate, compute_reward, compute_old_log_prob, compute_ref_log_prob, compute_value, compute_advantage,
train_critic, train_actor
- all_ranks: Whether to collect data from all ranks.
- ranks: List of ranks for which to collect data. If empty, no data is collected.
- profile_steps: List of collection steps. For example, [2, 4] indicates that steps 2 and 4 will be collected. If set to null, no data is collected.
Example
-------
Disable collection
~~~~~~~~~~~~~~~~~~~~
.. code:: yaml
profiler:
enable: False # disable profile
End-to-end collection
~~~~~~~~~~~~~~~~~~~~~
.. code:: yaml
profiler:
profile_steps: [1, 2, 5]
discrete: False
The run_qwen2_5-7b-npu-e2e_prof.sh script is provided in examples/grpo_trainer for reference.
Discrete mode collection
~~~~~~~~~~~~~~~~~~~~~~~~
.. code:: yaml
profiler:
discrete: True
roles: ['generate', 'train_actor']
A discrete-mode collection script, run_qwen2_5-7b-npu-discrete_prof.sh, is provided in examples/grpo_trainer for reference.
Visualization
-------------
The collected data is stored in the user-defined save_path and can be visualized using the MindStudio Insight tool.
If the analysis parameter is set to False, offline analysis is required after collection:
.. code:: python
import argparse

from torch_npu.profiler.profiler import analyse

parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True, help="save_path of the collected profiling data")

if __name__ == "__main__":
    args = parser.parse_args()
    # Run offline analysis on the data collected under --path.
    analyse(profiler_path=args.path)
================================================
FILE: docs/hardware_tutorial/ascend_quickstart.rst
================================================
Ascend NPU
==========
siiRL also supports Huawei's Ascend NPU devices. This guide has been tested with the following hardware:
- Atlas 200T A2 Box16
Installation Process
--------------------
Core Environment Requirements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Ensure your environment meets these core software version requirements:
+---------------------+------------+
| Software | Version |
+---------------------+------------+
| Python | == 3.10 |
+---------------------+------------+
| CANN | == 8.1.RC1 |
+---------------------+------------+
| PyTorch | == 2.5.1 |
+---------------------+------------+
| torch_npu | == 2.5.1 |
+---------------------+------------+
| mindspeed(Optional) | == 0.12.1 |
+---------------------+------------+
Recommended Base Image
^^^^^^^^^^^^^^^^^^^^^^
For a smoother setup, we strongly recommend using our pre-built Docker image, which includes all necessary dependencies. Note that this image already contains the torch, torch-npu, vLLM, and vLLM-Ascend packages; after pulling it, you only need to install the siiRL framework from source.
.. code-block:: bash
docker pull crispig/verl_npu:cann8.1rc1-py3.10-torch2.5.1-vllm-ascend0.7.3.post1-250616
Compiling vLLM and vllm-ascend [Optional]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Proper integration of vLLM within siiRL requires compiling both `vllm` and `vllm-ascend` from source. Follow the steps below, paying close attention to the instructions specific to your hardware.
.. note::
We recommend using vllm v0.9.2 and vllm-ascend v0.9.0rc2, which support setting use_remove_padding=True.
.. code-block:: bash
# vllm
git clone -b v0.9.2 --depth 1 https://github.com/vllm-project/vllm.git
cd vllm
pip install -r requirements-build.txt
# For Atlas 200T A2 Box16
VLLM_TARGET_DEVICE=empty pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu/
.. code-block:: bash
# vllm-ascend
git clone -b v0.9.0rc2 --depth 1 https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
export COMPILE_CUSTOM_KERNELS=1
python setup.py install
SiiRL Installation
^^^^^^^^^^^^^^^^^^
Finally, install the siiRL framework itself. Do NOT install siiRL from PyPI with ``pip install siirl``; it will cause dependency conflicts. Install from source instead:
.. code-block:: bash
git clone https://github.com/sii-research/siiRL.git
cd siirl
pip install -e .
Third-Party Library Considerations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Please be aware of the following specific requirements and limitations for certain libraries on Ascend hardware:
+--------------+---------------+
| Software | Description |
+--------------+---------------+
| transformers | v4.52.4 |
+--------------+---------------+
| flash_attn | not supported |
+--------------+---------------+
| liger-kernel | not supported |
+--------------+---------------+
| tensordict | 0.8.3 (ARM) |
+--------------+---------------+
1. Using `flash_attention_2` through `transformers` is supported (requires `transformers` version >= 4.52.0).
2. Flash Attention acceleration via the `flash_attn` package is not supported.
3. `liger-kernel` is not supported.
4. For ARM servers, `tensordict` version 0.8.3 is required. You can manually install it after the main dependencies are installed.
5. For x86 servers, the CPU version of `torchvision` must be installed.
.. code-block:: bash
pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu
Verification with a Quick Start Example
---------------------------------------
To ensure your setup is correct, we recommend performing a quick test run. The following example trains a Qwen2.5-0.5B model on the GSM8k dataset using the GRPO algorithm.
1. **Prepare the Dataset**
First, download and preprocess the GSM8k dataset. The provided script will convert it to the Parquet format required by the framework.
.. code-block:: bash
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
2. **Run the Training Job**
Next, execute the training command below. Ensure you have set the `VLLM_ATTENTION_BACKEND` environment variable.
.. code-block:: bash
set -x
python3 -m siirl.main_dag \
algorithm.adv_estimator=grpo \
data.train_files=/datasets/gsm8k/train.parquet \
data.val_files=/datasets/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=/models/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=5e-8 \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='siirl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2_05b_function_rm' \
trainer.n_gpus_per_node=16 \
trainer.nnodes=$NNODES \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=300 \
trainer.device=npu $@
(Optional) Setting Up MindSpeed Training Backend Guide
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Refer to the MindSpeed README for instructions on installing the MindSpeed acceleration library. Recommended versions: MindSpeed Core 0.12.1, Megatron-LM 0.12.2.
.. warning::
Please be sure to install **megatron-core** via ``pip install``.
Using ``PYTHONPATH`` to point to megatron will crash the program.
Set the siiRL worker model ``strategy`` to ``megatron``, for example: ``actor_rollout_ref.actor.strategy=megatron``.
Custom MindSpeed parameters can be passed through the override_transformer_config option. For instance, to enable Flash Attention for the actor model, you can use:
``+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True``.
MindSpeed provides the same support for siiRL as for verl. For more feature details, please refer to the MindSpeed+verl documentation.
================================================
FILE: docs/hardware_tutorial/metax_quickstart.rst
================================================
MetaX(沐曦) GPU
===============
siiRL also supports MetaX GPU devices. This guide has been tested with the following hardware:
- 曦云 series GPU
Installation Process
--------------------
Recommended Base Image
^^^^^^^^^^^^^^^^^^^^^^
For a smoother setup, we strongly recommend using our pre-built Docker image, which includes all necessary dependencies (see the MetaX developer website: https://developer.metax-tech.com/softnova/docker). After pulling it, you only need to install the siiRL framework from source.
.. code-block:: bash
docker pull siiai/siirl-metax:maca.ai3.1.0.1-torch2.6-py310-ubuntu22.04-amd64
Start docker container
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
docker run -d -t --net=host --uts=host --ipc=host --privileged=true --group-add video \
--shm-size 100gb --ulimit memlock=-1 --security-opt seccomp=unconfined \
--security-opt apparmor=unconfined --device=/dev/dri --device=/dev/mxcd --device=/dev/infiniband \
-v /data/:/data/ \
--name siirl \
siiai/siirl-metax:maca.ai3.1.0.1-torch2.6-py310-ubuntu22.04-amd64 bash
SiiRL Installation
^^^^^^^^^^^^^^^^^^
Finally, install the siiRL framework itself. Do NOT install siiRL from PyPI with ``pip install siirl``; it will cause dependency conflicts. Install from source instead:
.. code-block:: bash
git clone https://github.com/sii-research/siiRL.git
cd siirl
# You need to comment out the libraries adapted for MetaX, such as ray and vllm, to prevent them from being overwritten.
# vllm>=0.8.5.post1
# ray[default]>=2.47.1
pip install -r requirements.txt
pip install -e .
Add environment variables for MetaX
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# mx gpu env
export MACA_PATH=/opt/maca
export CUCC_PATH=${MACA_PATH}/tools/cu-bridge
export CUDA_PATH=${CUCC_PATH}
export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin
export PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}
export LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}
export PYTORCH_ENABLE_SAME_RAND_A100=1
export MCPYTORCH_DISABLE_PRINT=1
export MAX_JOBS=20
export VLLM_USE_V1=0
export MCCL_ENABLE_FC=0
export MCCL_MAX_NCHANNELS=8
export PYTHONUNBUFFERED=1
export MCCL_IB_HCA=mlx5
export MCCL_SOCKET_IFNAME=ens1
export GLOO_SOCKET_IFNAME=ens1
export SOCKET_NIC=ens1
Verification with a Quick Start Example
---------------------------------------
To ensure your setup is correct, we recommend performing a quick test run. The following example trains a Qwen2.5-0.5B model on the GSM8k dataset using the GRPO algorithm.
1. **Prepare the Dataset**
First, download and preprocess the GSM8k dataset. The provided script will convert it to the Parquet format required by the framework.
.. code-block:: bash
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
2. **Run the Training Job**
Next, execute the training command below. Ensure you have set the `VLLM_ATTENTION_BACKEND` environment variable.
.. code-block:: bash
# --- Experiment and Model Definition ---
export DATASET=gsm8k
export ALG=grpo
export MODEL_NAME=qwen2.5-05b
# --- Path Definitions ---
export HOME=/data/
export TRAIN_DATA_PATH=$HOME/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/$DATASET/test.parquet
export MODEL_PATH=$HOME/Qwen2.5-0.5B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# mx gpu env
export MACA_PATH=/opt/maca
export CUCC_PATH=${MACA_PATH}/tools/cu-bridge
export CUDA_PATH=${CUCC_PATH}
export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin
export PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}
export LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}
export PYTORCH_ENABLE_SAME_RAND_A100=1
export MCPYTORCH_DISABLE_PRINT=1
export MAX_JOBS=20
export VLLM_USE_V1=0
export MCCL_ENABLE_FC=0
export MCCL_MAX_NCHANNELS=8
export PYTHONUNBUFFERED=1
export MCCL_IB_HCA=mlx5
export MCCL_SOCKET_IFNAME=ens1
export GLOO_SOCKET_IFNAME=ens1
export SOCKET_NIC=ens1
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=False
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
# Fall back to SOCKET_NIC when INTERFACE_NAME is not provided externally.
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME:-$SOCKET_NIC}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME:-$SOCKET_NIC}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
ray stop --force
# export VLLM_USE_V1=0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: docs/index.rst
================================================
.. siiRL documentation master file, created by
sphinx-quickstart on Wed Jul 9 15:26:45 2025.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
siiRL documentation
===================
.. toctree::
:maxdepth: 2
:caption: Quickstart
start/install
start/quickstart
.. toctree::
:maxdepth: 2
:caption: Programming Guide
programming_guide/siirl_architecture_guide
programming_guide/code_structure
programming_guide/siiRL_code_explained
programming_guide/srpo_code_explained
.. toctree::
:maxdepth: 1
:caption: Data Preparation
preparation/prepare_data
preparation/reward_function
.. toctree::
:maxdepth: 2
:caption: User-Defined Interfaces
user_interface/filter_interface
user_interface/reward_interface
user_interface/pipeline_interface
user_interface/metrics_interface
.. toctree::
:maxdepth: 2
:caption: Configurations
examples/config
.. toctree::
:maxdepth: 1
:caption: Examples
examples/deepscaler_example
examples/mm_eureka_example
examples/cpgd_example
examples/megatron_backend_example
examples/embodied_srpo_example
.. toctree::
:maxdepth: 1
:caption: Hardware Support
hardware_tutorial/ascend_quickstart
hardware_tutorial/ascend_profiling_en
hardware_tutorial/metax_quickstart
================================================
FILE: docs/preparation/prepare_data.rst
================================================
Prepare Data for Post-Training
========================================
Before starting the post-training job, we need to prepare the data for policy training. The data should be preprocessed and stored in Parquet format, which facilitates efficient distributed data loading and processing.
We provide several data preprocessing scripts for popular datasets under the ``examples/data_preprocess/`` directory, such as ``gsm8k.py``, ``math_dataset.py``, and ``deepscaler.py``. To support a new custom dataset, you will need to create a similar script.
This document uses the ``DeepScaleR`` dataset as an example to detail the data preparation process and its specifications.
General Data Preprocessing Workflow
-----------------------------------
A typical data preprocessing script involves the following steps:
1. **Load Raw Data**: Use a library like Hugging Face's ``datasets`` to load the original dataset from the Hub or local files.
2. **Define Processing Logic**: Implement a core mapping function (which we often name ``make_map_fn``) to convert each sample from the original dataset into the specific format required by our framework.
3. **Apply Transformation and Save**: Use the ``datasets.map()`` method to apply this function to the entire dataset. Then, save the processed result in Parquet format locally, with an option to upload it to a distributed file system like HDFS.
Here is a simplified framework of the process:
.. code:: python
import argparse
import os
import datasets
from siirl.utils.extras.hdfs_io import copy, makedirs
def make_map_fn(split_name):
# ... Define your data processing logic here ...
def process_fn(example, idx):
# ... Transform each data sample ...
return transformed_data
return process_fn
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# ... Define arguments ...
args = parser.parse_args()
# 1. Load data
raw_dataset = datasets.load_dataset(...)
# 2. Apply transformation
processed_dataset = raw_dataset.map(function=make_map_fn('train'), with_indices=True)
# 3. Save as Parquet
local_dir = args.local_dir
processed_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
# (Optional) Upload to HDFS
if args.hdfs_dir:
makedirs(args.hdfs_dir)
copy(src=local_dir, dst=args.hdfs_dir)
DeepScaleR Dataset Processing in Practice
-------------------------------------------
Let's take ``examples/data_preprocess/deepscaler.py`` as a concrete example to demonstrate how to process the ``agentica-org/DeepScaleR-Preview-Dataset``.
The core task is to implement the ``make_map_fn`` function, which maps original fields (like ``problem``, ``answer``, and ``solution``) to the standard format required by the training framework.
.. code:: python
data_source = "agentica-org/DeepScaleR-Preview-Dataset"
instruction_following = 'Let\'s think step by step and output the final answer within \\boxed{}.'
def make_map_fn(split_name):
def process_fn(example, idx):
question_raw = example.pop("problem")
answer_raw = example.pop("answer")
question = question_raw + " " + instruction_following
solution = example.pop("solution")
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": question,
}
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": answer_raw},
"extra_info": {
"split": split_name,
"index": idx,
"answer": solution,
"question": question_raw,
},
}
return data
return process_fn
Data Format Specification
-------------------------
To ensure the framework can correctly parse and utilize the data, each sample processed by ``make_map_fn`` must contain the following five key fields:
1. ``data_source``: A string indicating the source or name of the dataset. This field is used to dynamically select the corresponding reward function during training.
- Example: ``"agentica-org/DeepScaleR-Preview-Dataset"``
2. ``prompt``: A list used to construct the model's input, formatted to be compatible with Hugging Face's Chat Template. The data loader will automatically apply the template and tokenize the input.
- Example: ``[{"role": "user", "content": "What is 2+2? Let's think step by step..."}]``
3. ``ability``: A string defining the domain or capability of the current task, such as ``"math"``, ``"coding"``, or ``"general"``.
4. ``reward_model``: A dictionary containing information needed to calculate the reward. Currently, the ``ground_truth`` field is primarily used during evaluation.
- **Note**: The ``ground_truth`` you provide must align with the logic of the corresponding reward function you implement. For a math problem, it might be the standard answer; for code generation, it could be a set of unit tests.
- Example: ``{"style": "rule", "ground_truth": "\\boxed{4}"}``
5. ``extra_info``: A dictionary for storing additional metadata, such as the original dataset split (train/test) or sample index. This information is not used directly in training but is useful for debugging and data traceability.
By following these specifications, you can prepare your dataset to be used smoothly within the siiRL post-training pipeline.
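For reference, here is what a single processed sample might look like once all five fields are assembled. The values below are purely illustrative, echoing the per-field examples above:

.. code:: python

    # Illustrative sample only: values echo the examples above and are not
    # drawn from a real dataset.
    sample = {
        "data_source": "agentica-org/DeepScaleR-Preview-Dataset",
        "prompt": [
            {"role": "user", "content": "What is 2+2? Let's think step by step..."}
        ],
        "ability": "math",
        "reward_model": {"style": "rule", "ground_truth": "\\boxed{4}"},
        "extra_info": {"split": "train", "index": 0, "question": "What is 2+2?"},
    }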
================================================
FILE: docs/preparation/reward_function.rst
================================================
Implementing Reward Functions for Datasets
===========================================
In Reinforcement Learning for LLMs, the reward function is a critical component that guides the model's learning process. It quantitatively evaluates the quality of a generated response, signaling what constitutes a "good" or "bad" output. Our framework provides a flexible system for defining these rewards, supporting both pre-implemented logic for common datasets and fully customized functions for specific tasks.
The RewardManager
-----------------
The ``RewardManager`` is the central hub for reward computation. As defined in `siirl/scheduler/reward.py`, its primary role is to orchestrate the scoring of generated responses by invoking a specified scoring function. Different managers, like `NaiveRewardManager` or `BatchRewardManager`, offer different strategies for handling this process. This design is consistent with the `verl` framework's architecture. [1]_
The typical workflow is as follows:
1. The manager receives a `DataProto` object, which is a batch containing all necessary information.
2. It extracts relevant fields, such as the model's generated text (`solution_strs`) and the reference answer (`ground_truth`).
3. It passes this data to a designated scoring function (`compute_score_fn`) to calculate the reward for each item in the batch.
This design allows the core training loop to remain agnostic to the specifics of reward calculation, which are neatly encapsulated within the manager and its scoring function.
Reward Function Implementations
-------------------------------
You can define reward logic in two ways: by using our pre-built functions or by creating your own.
Pre-implemented Functions
~~~~~~~~~~~~~~~~~~~~~~~~~
For standard benchmarks, we provide ready-to-use reward functions in the `siirl/utils/reward_score/` directory. These cover datasets like `GSM8K` and `MATH`, implementing their standard evaluation logic. For instance, the `GSM8K` scorer extracts the final numerical answer and compares it to the ground truth.
Customized Functions
~~~~~~~~~~~~~~~~~~~~
For novel tasks or custom evaluation criteria, you can supply your own reward function. This is configured via two parameters: `custom_reward_function.path` and `custom_reward_function.name`.
Let's consider a practical example from the `run_qwen2_5-7b-custom_reward.sh` script, which uses a batch-processing reward function for efficiency.
**1. Configuration in the script:**
The script specifies the path to the custom code, the function to use, and selects the `BatchRewardManager` to execute it.
.. code-block:: bash
# ... other configurations ...
python3 -m siirl.main_dag \
# ...
custom_reward_function.path=$HOME/rl/rewardfunc_gsm8k.py \
custom_reward_function.name=compute_score \
reward_model.reward_manager=batch \
# ...
**2. Implementation of the reward function:**
The corresponding `rewardfunc_gsm8k.py` file implements the `compute_score` function. This function is designed to process an entire batch of solutions at once, which is significantly more efficient than processing them one by one.
.. code:: python
import re
def extract_solution(solution_str, method="strict"):
# ... (logic to extract the final answer from text)
# For example, finds the number after "####"
if method == "strict":
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
if solution is None: return None
final_answer = solution.group(0).split("#### ")[1].replace(",", "")
return final_answer
# ... other extraction logic ...
def compute_score(data_sources, solution_strs, ground_truths, extra_infos, method="strict", score=1.0, **kwargs):
"""
Computes scores for a batch of solutions.
"""
scores = []
for solution_str, ground_truth in zip(solution_strs, ground_truths):
answer = extract_solution(solution_str=solution_str, method=method)
if answer is not None and answer == ground_truth:
scores.append(score)
else:
scores.append(0.0)
return scores
The function signature should accept lists of `solution_strs` and `ground_truths`. You can also pass custom parameters from your configuration, like `method` or `score`, by defining them under `custom_reward_function.reward_kwargs`. This allows you to easily experiment with different reward schemes without changing the code.
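As a concrete sketch (assuming the manager forwards ``reward_kwargs`` verbatim as keyword arguments; the exact call site may differ), setting ``custom_reward_function.reward_kwargs.method=strict`` and ``custom_reward_function.reward_kwargs.score=0.5`` would make the manager's invocation equivalent to:

.. code:: python

    # Hypothetical call with toy inputs, showing how reward_kwargs reach
    # compute_score as keyword arguments.
    scores = compute_score(
        data_sources=["openai/gsm8k"],
        solution_strs=["The answer is #### 72"],
        ground_truths=["72"],
        extra_infos=[None],
        method="strict",
        score=0.5,
    )
    # -> [0.5]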
.. [1] https://verl.readthedocs.io/en/latest/preparation/reward_function.html
================================================
FILE: docs/programming_guide/code_structure.rst
================================================
===============
Code Structure
===============
This document describes the code structure and architecture of siiRL.
Directory Structure
-------------------
.. code-block:: text
siirl/
├── main_dag.py # Main entry point
├── dag_worker/ # DAG Worker implementation
├── execution/ # Execution engine
├── engine/ # Model engine
├── data_coordinator/ # Data coordination
├── params/ # Configuration parameters
├── environment/ # Environment abstraction
└── user_interface/ # User interface
Core Modules
------------
dag_worker/
~~~~~~~~~~~
DAG execution unit, one worker per GPU.
.. code-block:: text
dag_worker/
├── dagworker.py # Core Worker class (~1320 lines)
├── core_algos.py # RL algorithm implementations
├── dag_utils.py # Utility functions
├── checkpoint_manager.py # Checkpoint management
├── metrics_collector.py # Metrics collection
├── metric_aggregator.py # Metrics aggregation
├── validator.py # Validation logic
├── constants.py # Constants
└── data_structures.py # Data structures
**Responsibilities:**
- Execute TaskGraph nodes
- Manage model Workers (Actor/Critic/Rollout/Reference/Reward)
- Data flow and caching
- Metrics collection and reporting
- Checkpoint saving and loading
execution/
~~~~~~~~~~
Execution engine for DAG definition, scheduling, and metrics aggregation.
.. code-block:: text
execution/
├── dag/ # DAG definition
│ ├── task_graph.py # TaskGraph class
│ ├── node.py # Node class
│ ├── builtin_pipelines.py # Built-in Pipelines
│ ├── pipeline.py # Pipeline Builder API
│ ├── config_loader.py # Configuration loader
│ └── task_loader.py # Task loader
├── scheduler/ # Task scheduling
│ ├── task_scheduler.py # Task scheduler
│ ├── launch.py # Ray launcher
│ ├── process_group_manager.py # Process group manager
│ ├── graph_updater.py # Graph updater
│ ├── reward.py # Reward scheduler
│ ├── enums.py # Enum definitions
│ └── resource_manager.py # Resource manager
├── metric_worker/ # Distributed metrics aggregation
│ ├── metric_worker.py # MetricWorker
│ └── utils.py
└── rollout_flow/ # Rollout flow
├── multi_agent/ # Multi-agent support
└── multiturn/ # Multi-turn interaction
**Responsibilities:**
- DAG definition and validation
- Task scheduling and resource allocation
- Distributed metrics collection
- Multi-agent/multi-turn interaction flow
engine/
~~~~~~~
Model execution engine containing all model workers.
.. code-block:: text
engine/
├── actor/ # Actor models
│ ├── base.py
│ ├── dp_actor.py # FSDP Actor
│ ├── megatron_actor.py # Megatron Actor
│ └── embodied_actor.py # Embodied Actor
├── critic/ # Critic models
│ ├── base.py
│ ├── dp_critic.py
│ └── megatron_critic.py
├── rollout/ # Rollout engine
│ ├── base.py
│ ├── vllm_rollout/ # vLLM backend
│ ├── sglang_rollout/ # SGLang backend
│ ├── hf_rollout.py # HuggingFace backend
│ └── embodied_rollout.py # Embodied Rollout
├── reward_model/ # Reward models
├── reward_manager/ # Reward managers
│ ├── naive.py # Simple reward
│ ├── batch.py # Batch Reward Model
│ ├── parallel.py # Parallel Reward Model
│ ├── dapo.py # DAPO Reward
│ └── embodied.py # Embodied Reward
├── sharding_manager/ # Sharding management
├── base_worker/ # Worker base classes
├── fsdp_workers.py # FSDP Worker
└── megatron_workers.py # Megatron Worker
**Responsibilities:**
- Training and inference for Actor/Critic/Rollout/Reference/Reward models
- Support for FSDP and Megatron backends
- Support for vLLM/SGLang/HuggingFace inference backends
data_coordinator/
~~~~~~~~~~~~~~~~~
Data coordinator for distributed data management.
.. code-block:: text
data_coordinator/
├── data_buffer.py # Distributed data buffer
├── dataloader/ # Data loading
│ ├── data_loader_node.py
│ ├── partitioned_dataset.py
│ ├── embodied_preprocess.py
│ └── vision_utils.py
├── protocol.py # Data protocol
└── sample.py # Sampling logic
**Responsibilities:**
- Distributed data buffering (per-server)
- Data loading (per-GPU)
- Data redistribution and load balancing
params/
~~~~~~~
Parameter configuration using Hydra.
.. code-block:: text
params/
├── __init__.py # SiiRLArguments
├── parser.py # Configuration parser
├── data_args.py # Data parameters
├── model_args.py # Model parameters
├── training_args.py # Training parameters
├── dag_args.py # DAG parameters
├── embodied_args.py # Embodied parameters
└── profiler_args.py # Profiler parameters
environment/
~~~~~~~~~~~~
Environment abstraction for Embodied AI and multi-agent systems.
.. code-block:: text
environment/
└── embodied/
├── base.py # Environment base class
├── venv.py # Vectorized environment
└── adapters/ # Environment adapters
└── libero.py # Libero adapter
user_interface/
~~~~~~~~~~~~~~~
User-defined interfaces.
.. code-block:: text
user_interface/
├── filter_interface/
│ ├── dapo.py # DAPO dynamic sampling
│ └── embodied.py # Embodied data filtering
└── rewards_interface/
└── custom_gsm8k_reward.py # Custom reward example
**Purpose:** Provides interfaces for user-defined node functions.
Data Structures
---------------
NodeOutput
~~~~~~~~~~
Return value from node execution.
.. code-block:: python
    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class NodeOutput:
        batch: Any                      # Data batch
        metrics: Optional[Dict] = None  # Metrics
        info: Optional[Dict] = None     # Additional info
Node
~~~~
DAG node definition.
.. code-block:: python
@dataclass
class Node:
node_id: str # Node ID
node_type: NodeType # Node type
node_role: NodeRole # Node role
dependencies: List[str] # Dependency nodes
executable: Callable # Executable function
executable_ref: str # Function path
only_forward_compute: bool # Forward only
Enumerations
~~~~~~~~~~~~
**NodeType:**
.. code-block:: python
class NodeType(Enum):
MODEL_INFERENCE = "model_inference"
MODEL_TRAIN = "model_train"
COMPUTE = "compute"
DATA_LOAD = "data_load"
**NodeRole:**
.. code-block:: python
class NodeRole(Enum):
ROLLOUT = "rollout"
ACTOR = "actor"
CRITIC = "critic"
REFERENCE = "reference"
REWARD = "reward"
ADVANTAGE = "advantage"
DYNAMIC_SAMPLING = "dynamic_sampling"
DEFAULT = "default"
**AdvantageEstimator:**
.. code-block:: python
class AdvantageEstimator(Enum):
GRPO = "grpo"
GAE = "gae"
CPGD = "cpgd"
GSPO = "gspo"
**WorkflowType:**
.. code-block:: python
class WorkflowType(Enum):
DEFAULT = "DEFAULT"
DAPO = "DAPO"
EMBODIED = "EMBODIED"
Execution Flow
--------------
Startup Flow (main_dag.py)
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: text
1. Parse configuration (parse_config)
2. Load Pipeline (load_pipeline)
3. Initialize DataBuffer (init_data_coordinator)
4. Initialize MetricWorker
5. Task scheduling (TaskScheduler)
6. Launch Ray cluster (RayTrainer)
7. Create DAGWorker (one per GPU)
8. Execute training (DAGWorker.execute_task_graph)
DAGWorker Execution Flow
~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: text
1. Initialize Workers (Actor/Critic/Rollout/Reference/Reward)
2. Initialize DataLoader
3. Initialize Validator
4. Load Checkpoint (if exists)
5. Training loop:
- Load data
- Execute nodes in topological order
- Collect metrics
- Save Checkpoint
- Validate (if needed)
Node Execution Flow
~~~~~~~~~~~~~~~~~~~
.. code-block:: text
1. DAGWorker gets node's executable function
2. Call function with current batch
3. Function processes data, returns NodeOutput
4. Update batch, pass to next node
5. Collect node metrics
Key Concepts
------------
TaskGraph
~~~~~~~~~
Directed Acyclic Graph representing training workflow.
**Core Methods:**
- ``add_node()``: Add node
- ``build_adjacency_lists()``: Build adjacency lists
- ``validate_graph()``: Validate DAG
- ``get_execution_order()``: Get topological sort
Pipeline
~~~~~~~~
Declarative API for building TaskGraph.
**Core Methods:**
- ``add_node()``: Add node (supports chaining)
- ``build()``: Build and validate TaskGraph
DAGWorker Class
~~~~~~~~~~~~~~~
Execution unit per GPU.
**Core Methods:**
- ``generate()``: Rollout generation
- ``compute_reward()``: Compute reward
- ``compute_advantage()``: Compute advantage
- ``compute_old_log_prob()``: Old policy log prob
- ``compute_ref_log_prob()``: Reference model log prob
- ``compute_value()``: Value function (PPO)
- ``train_actor()``: Train actor
- ``train_critic()``: Train critic (PPO)
Configuration Parameters
------------------------
Main Configuration Groups
~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: yaml
algorithm:
adv_estimator: grpo # grpo/gae/cpgd/gspo
workflow_type: DEFAULT # DEFAULT/DAPO/EMBODIED
data:
train_files: /path/to/train.parquet
train_batch_size: 512
max_prompt_length: 2048
max_response_length: 4096
actor_rollout_ref:
model:
path: /path/to/model
actor:
optim:
lr: 1e-6
ppo_mini_batch_size: 256
rollout:
name: vllm # vllm/sglang/hf
tensor_model_parallel_size: 2
n: 8 # GRPO group size
trainer:
n_gpus_per_node: 8
nnodes: 1
total_epochs: 30
save_freq: 10
dag:
custom_pipeline_fn: null # Custom Pipeline
Extension Points
----------------
Custom Pipeline
~~~~~~~~~~~~~~~
Add new functions in ``siirl/execution/dag/builtin_pipelines.py``.
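As a minimal sketch (node ids, function paths, and the import path are illustrative; see ``examples/custom_pipeline_example/custom_grpo.py`` in the repository for a complete example):

.. code-block:: python

    # Hypothetical two-node pipeline; real pipelines chain more nodes
    # (reward, advantage, log-prob computation, etc.).
    from siirl.execution.dag.pipeline import Pipeline

    def my_grpo_pipeline():
        p = Pipeline("my_grpo_pipeline")
        p.add_node("rollout_actor",
                   func="siirl.dag_worker.dagworker:DAGWorker.generate") \
         .add_node("actor_train",
                   func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
                   deps=["rollout_actor"])
        return p.build()  # validates the DAG and returns a TaskGraph

A function like this is what the ``dag.custom_pipeline_fn`` configuration option shown above is meant to point at.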
Custom Node Functions
~~~~~~~~~~~~~~~~~~~~~
Implement functions following the signature:
.. code-block:: python
def my_node(batch, config=None, **kwargs) -> NodeOutput:
return NodeOutput(batch=batch, metrics={})
Custom Reward Manager
~~~~~~~~~~~~~~~~~~~~~
Add new classes in ``siirl/engine/reward_manager/``.
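The exact base-class interface lives alongside the existing managers; the following is only a rough, hypothetical sketch of the shape such a class takes, mirroring the manager workflow described in the reward-function docs (all names and batch keys here are illustrative):

.. code-block:: python

    # Hypothetical sketch: extract generated text and references from the
    # batch, then delegate scoring to a pluggable function.
    class MyRewardManager:
        def __init__(self, compute_score_fn):
            self.compute_score_fn = compute_score_fn

        def compute_reward(self, batch):
            # Batch keys below are illustrative, not the real protocol fields.
            return self.compute_score_fn(
                data_sources=batch["data_source"],
                solution_strs=batch["responses"],
                ground_truths=batch["ground_truth"],
                extra_infos=batch.get("extra_info"),
            )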
Custom Environment
~~~~~~~~~~~~~~~~~~
Add new environment classes in ``siirl/environment/``.
================================================
FILE: docs/programming_guide/siiRL_code_explained.rst
================================================
siiRL's Implementation Explained
================================
siiRL is under active development with an extensive roadmap for future enhancements. We strongly encourage community participation in this endeavor. Contributions in any form are highly valued, including but not limited to: filing issues, proposing new features, enhancing documentation, and providing suggestions for improvement.
Overall Implementation
----------------------
RL training is inherently workflow-like, and a DAG is the standard tool for describing workflows. The siiRL source code therefore adopts a DAG-based design: siiRL abstracts the entire RL training task into a TaskGraph composed of multiple Nodes, each of which implements the ``node.run()`` method to support the abstract orchestration of the top-level TaskGraph. The constructed TaskGraph is submitted to a set of DAGWorkers for execution.
In multi-agent RL training, different DAGWorkers can process different TaskGraphs in parallel, and the data that different TaskGraphs depend on and process may also vary. Structurally, siiRL therefore follows the MPMD (Multiple Program, Multiple Data) paradigm.
From the user's perspective, in addition to the Data/Trainer/Model/RL-algorithm configuration found in mainstream RL frameworks, siiRL provides a DAG config that lets users customize workflows. The system parses the DAG configuration at training start and constructs a corresponding TaskGraph instance.
Complex task workflows place higher demands on resource scheduling. To achieve fine-grained GPU allocation, siiRL implements a TaskScheduler responsible for globally optimal scheduling decisions: how much compute to allocate to each TaskGraph, and which GPU devices on which servers to use. The allocation plan generated by the TaskScheduler is handed to the underlying Ray framework for execution, making full use of Ray's distributed computing capabilities.
.. figure:: ../../asset/code_explained/siirl_arch.png
:width: 60%
:align: center
:alt: Overall Architecture of siiRL's Code Implementation
Figure 1: Overall Architecture of siiRL
We will first provide an overview diagram of the siiRL source code implementation, and then, in the following text, we will introduce each part of the diagram in detail according to the actual execution process.
.. figure:: ../../asset/code_explained/overview_diagram.png
:width: 100%
:align: center
:alt: Diagram of Source Code Implementation
Figure 2: Diagram of Source Code Implementation
Environment Abstraction
-----------------------
In the initial RL stage of LLM post-training, the environment typically refers to the datasets used. siiRL abstracts the concept of an environment to uniformly support RL tasks across application areas, such as MCP calls and SandBox Servers in agentic training scenarios, simulators in the embodied AI domain, and real physical environments for agent interaction.
Similar to OpenAI Gym, siiRL defines two core asynchronous methods:
- ``reset()``: Resets the environment to its initial state and returns the initial observation. This function marks the start of a new episode.
- ``step(actions)``: Receives actions from one or multiple agents, executes these actions, updates the environment state, and returns a tuple containing (observation at the next time step, reward, information). This is the main loop for agent-environment interaction.
Taking MathEnv for mathematical tasks as an example, the environment natively supports multiple agents: the step function receives a batch of actions (one per agent), and the returned observations are likewise an array with one entry per agent.
.. code-block:: python
class MathEnv(BaseEnvironment):
async def reset(self, dp_rank: int, ddp_world_size: int, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):
# ...
obs = np.array([self.current_state for _ in range(self.n_agents)], dtype=np.object_)
self.step_count = 0
return obs
async def step(self, actions):
# ...
return next_obs, rewards, infos
Control Flow: Pipeline
----------------------
The main pipeline of the siiRL control flow is shown in the figure below. It first loads the configuration of the interactive environment, then sequentially initializes the DataBuffer, loads and parses the DAG configuration, and constructs the TaskGraph. Once the TaskGraph is built, the TaskScheduler makes the scheduling decisions, determining how many GPUs to allocate to each task and computing the concrete allocation topology. Ray is then used to construct the distributed process group and initialize RayTrainer. Finally, the DAGWorkers (Ray actors) are initialized and the training task starts.
.. figure:: ../../asset/code_explained/pipeline.png
:width: 40%
:align: center
:alt: Pipeline of Control Flow
Figure 3: Pipeline of Control Flow
DataLoader and DataBuffer
-------------------------
DataLoader is a wrapper around torch's StatefulDataLoader which, combined with the custom PartitionedRLHFDataset, handles loading, preprocessing, and batching of training data. Unlike other open-source RL frameworks, the DataLoader in siiRL is itself abstracted as a Node (DataLoaderNode) and embedded in the TaskGraph for execution. For typical cluster sizes and RL tasks, siiRL launches one data_loader process per GPU rank, responsible for loading the data shard corresponding to the DAGWorker on that rank.
.. code-block:: python

    class DataLoaderNode(Node):
        """
        Represents a data loader node in the DAG.

        This version uses the PartitionedRLHFDataset for efficient, memory-safe
        distributed data loading. Each rank only loads and processes its own
        data slice.
        """

        def run(self, epoch: Optional[int] = None, is_validation_step: bool = False, **kwargs: Any) -> Any:
            """
            Executes the data loading process for a training or validation step.
            """
            if is_validation_step:
                try:
                    batch = next(self._current_val_iter)
                except StopIteration:
                    ...  # reset the validation iterator and retry (elided in this excerpt)
            else:
                try:
                    batch = next(self._current_train_iter)
                except StopIteration:
                    ...  # reset the training iterator and retry (elided in this excerpt)
            return batch
DataBuffer is essentially a distributed KV store maintained by an independent Ray actor process. Typically, the DataLoader is per-GPU while the DataBuffer is per-server. In static batching mode, siiRL checks load balance when creating the DataBuffer, as shown in the figure below. For example, if the training batch size is 128, it must be divisible by the number of servers so that a global batch can be evenly distributed among them. Similarly, the batch size allocated to a server, after being replicated ``n`` times (the group size in GRPO, or ``n = 1`` for PPO), must be divisible by the number of GPUs per server (typically 8) so that it can be evenly distributed among that server's GPUs.
.. figure:: ../../asset/code_explained/data_loader.png
:width: 95%
:align: center
:alt: Diagram of Source Code Implementation
Figure 4: DataLoader, DataBuffer and Load Balance
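A quick worked example of the divisibility checks described above (numbers are illustrative):

.. code-block:: python

    # Static-batching load balance: the global batch must split evenly
    # across servers, and each server's share (after GRPO replication)
    # must split evenly across that server's GPUs.
    train_batch_size = 128
    num_servers = 4
    gpus_per_server = 8
    n = 8  # GRPO group size (n = 1 for PPO)

    per_server = train_batch_size // num_servers       # 32 prompts per server
    assert train_batch_size % num_servers == 0
    assert (per_server * n) % gpus_per_server == 0     # 256 samples -> 32 per GPU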
TaskGraph Scheduling
--------------------
The core of a TaskGraph is a dictionary of Nodes; adjacency lists and reverse adjacency lists represent the connections between them. The reverse adjacency list is mainly used for dependency checking, e.g., Actor training depending on rollout generation. The TaskGraph also provides a set of graph operations for managing itself, such as adding, deleting, modifying, and querying nodes, DAG validation, copying, and displaying the graph.
.. code-block:: python
class Node:
"""
Represents a node (task unit) in the DAG.
"""
class TaskGraph:
"""
Represents a Directed Acyclic Graph (DAG) of tasks,
composed of multiple Node objects and their dependencies.
"""
def __init__(self, graph_id: str):
"""
Initialize a task graph.
Parameters:
graph_id (str): The unique identifier of the graph.
"""
self.graph_id: str = graph_id
self.nodes: Dict[str, Node] = {}
self.adj: Dict[str, List[str]] = {}
self.rev_adj: Dict[str, List[str]] = {}
The scheduling of TaskGraph includes four key steps:
1. **TaskGraph Splitting**: When a user-defined workflow contains parallel paths—as seen in multi-agent training where agents use both shared and specific Nodes—siiRL splits the original TaskGraph into multiple subgraphs for sequential execution. While this approach may not be the most efficient, it significantly simplifies resource scheduling.
2. **SubGraph Sorting**: To allocate resources reasonably, siiRL sorts all SubGraphs. The sorting is mainly based on two points. First, the size of the SubGraph, where this size refers to the parameter scale of the model to be trained on the current SubGraph (7B, 32B, 671B, etc.), with priority given to resource allocation for SubGraphs with larger parameter scales. Second, the number of Nodes on the SubGraph; the more Nodes, i.e., the "longer the chain" of the SubGraph, the earlier it is allocated.
3. **GPU Quota Allocation**: Based on the sorting from Step 2, allocate a number of GPUs to each SubGraph. There are two allocation strategies: even and param_aware. In even mode, the total number of GPUs is distributed as evenly as possible among SubGraphs; in param_aware mode, provided that each SubGraph receives at least one GPU, SubGraphs with larger parameter scales receive more GPUs.
4. **GPU Topology Allocation**: Given the per-SubGraph GPU counts from Step 3, this step assigns concrete devices. Suppose there are three SubGraphs sg1, sg2, sg3, the training cluster consists of 2 machines with 16 GPUs in total, and the counts from Step 3 are (6, 5, 5); this step determines specifically which 6 GPUs go to sg1, which 5 to sg2, and which 5 to sg3 (a worked sketch of Step 3 follows Figure 5). siiRL makes these decisions through a scoring mechanism:
``(cohesion_score(+), node_load_score(-), rank_preference_score(-))``
Where ``cohesion_score`` rewards cohesion: place a SubGraph within a single server as much as possible to reduce communication; ``node_load_score`` is a load penalty: balance placement across servers; and ``rank_preference_score`` encodes a rank preference: favor GPUs with smaller rank numbers so that scheduling behavior is more predictable.
.. figure:: ../../asset/code_explained/taskgraph_sched.png
:width: 95%
:align: center
:alt: TaskGraph Scheduling
Figure 5: TaskGraph Scheduling
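To make Step 3 concrete, here is a small worked sketch of the ``even`` strategy applied to the (6, 5, 5) example above (the real logic lives in the TaskScheduler; this is just the arithmetic):

.. code-block:: python

    # Even quota allocation: 16 GPUs over 3 SubGraphs. Each gets
    # floor(16 / 3) = 5, and the remainder goes to the SubGraphs ranked
    # highest by the Step-2 sort, yielding (6, 5, 5).
    num_gpus, num_subgraphs = 16, 3
    base, remainder = divmod(num_gpus, num_subgraphs)  # 5, 1
    quotas = [base + (1 if i < remainder else 0) for i in range(num_subgraphs)]
    assert quotas == [6, 5, 5]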
Build the Distributed Process Group
-----------------------------------
After task scheduling completes, Ray's distributed process groups can be constructed. Based on the topology determined by the scheduling above, an affiliated process group is built for each Node of the TaskGraph.
For example, for Actor training (described as ``NodeRole=Actor, NodeType=Train`` in siiRL), if the assigned ranks are ``[0, 1, 2, 3, 4, 5]``, the group is registered with a Python tuple as the key and a unique string as the value: ``(0,1,2,3,4,5): "process_group_1"``
.. figure:: ../../asset/code_explained/dist_pg.png
:width: 95%
:align: center
:alt: Distributed Process Group
Figure 6: Distributed Process Group
Ray Trainer
-----------
After constructing the process groups, RayTrainer is initialized. This part is similar to other mainstream frameworks; its core is instantiating Ray's resource pool management, i.e., the resource_manager. Finally, the configurations of all Nodes (Actor/Rollout/Reward, etc.) are validated collectively.
.. figure:: ../../asset/code_explained/ray_trainer.png
:width: 95%
:align: center
:alt: Ray Trainer
Figure 7: Ray Trainer
DAGWorker
---------
Through its series of DAG and TaskGraph abstractions, siiRL encapsulates and hides the training job flow beneath the control flow. The call logic for the training backend, inference backend, sharding manager, and so on, which is directly visible in veRL's control flow, is encapsulated inside DAGWorker in siiRL and is almost invisible from the control flow. As a programming model, this hiding provides a higher level of abstraction, enabling more convenient modular reuse and more flexible extensibility than other mainstream frameworks, though it may make bug localization harder.
In the source code, DAGWorker is modularized via mixin classes. There are five core mixins, responsible respectively for initialization, pipeline execution, execution of specific Nodes, training validation, and utility functions, as shown below.
.. figure:: ../../asset/code_explained/dag_worker.png
:width: 70%
:align: left
:alt: DAG Worker
When initializing DAGWorker, first call resource_manager (the one created during RayTrainer initialization) to create ResourcePool, then create RayActorManager to manage the lifecycle of all distributed DAGWorkers. Finally, call the method defined in the InitializationMixin mixin class to complete the initialization of DAGWorker.
.. figure:: ../../asset/code_explained/dag_init.png
:width: 80%
:align: center
:alt: Initialization of DAG Worker
Figure 8: Initialization of DAG Worker
When setting up the communication group, siiRL adopts the following strategy: if the total number of ranks is less than 256, it uses the pure NCCL backend; otherwise, it uses the GLOO+NCCL hybrid backend. In the hybrid backend mode, GLOO is mainly used for aggregated communication of data such as logs and metrics.
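Schematically, the selection rule can be sketched as follows (a sketch only, assuming the threshold is applied at process-group initialization; PyTorch accepts per-device backend strings such as ``cpu:gloo,cuda:nccl``):

.. code-block:: python

    import torch.distributed as dist

    def init_comm_backend(world_size: int, rank: int) -> None:
        # Fewer than 256 ranks: pure NCCL. Otherwise hybrid, with GLOO
        # handling CPU-side traffic such as log and metric aggregation.
        backend = "nccl" if world_size < 256 else "cpu:gloo,cuda:nccl"
        dist.init_process_group(backend=backend, world_size=world_size, rank=rank)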
Training Initiation
-------------------
The main pipeline initiates training in the final step. Here, it primarily calls the ``execute_task_graph`` method in the ExecutionMixin mixin class. This method encapsulates the outer loop of epochs and the inner loop of batches within each epoch (i.e., a training step).
.. figure:: ../../asset/code_explained/train_init.png
:width: 70%
:align: center
:alt: Training Job Initialization
Figure 9: Training Job Initialization
Each training step is no longer "concrete and expanded" as in mainstream frameworks such as veRL, but "abstract and cyclic": traverse all Nodes in the graph, execute each Node's run method, and write the resulting data to the DataBuffer, where the key is the node_id of the next node and the value is the run method's output.
.. figure:: ../../asset/code_explained/data_buffer_loop.png
:width: 70%
:align: center
:alt: Loop of TaskGraph Computation based on DataBuffer
Figure 10: Loop of TaskGraph Computation based on DataBuffer
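Schematically, the per-step loop described above looks like the following (method names echo the TaskGraph and DataBuffer APIs discussed earlier; this is a sketch, not the actual ExecutionMixin code):

.. code-block:: python

    # One "abstract and cyclic" training step: run nodes in topological
    # order and stage each output under the downstream node's id.
    for node in task_graph.get_execution_order():
        output = node.run(batch)
        for next_id in task_graph.adj[node.node_id]:
            data_buffer.put(key=next_id, value=output)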
================================================
FILE: docs/programming_guide/siirl_architecture_guide.rst
================================================
=======================================
siiRL Complete Architecture Guide
=======================================
.. note::
**Target Audience**: This document assumes no prior knowledge of siiRL, but expects basic familiarity with Python, PyTorch, and reinforcement learning concepts.
We will systematically explain siiRL's design philosophy, architecture implementation, and core algorithms from the ground up.
Table of Contents
=================
- :ref:`sec1_overview`
- :ref:`sec2_design_philosophy`
- :ref:`sec3_main_entry`
- :ref:`sec4_dag_planner`
- :ref:`sec5_dag_worker`
- :ref:`sec6_data_coordinator`
- :ref:`sec7_engine`
- :ref:`sec8_core_algorithms`
- :ref:`sec9_execution_flow`
- :ref:`sec10_configuration`
- :ref:`sec11_extension_guide`
----
.. _sec1_overview:
1. siiRL Architecture Overview
==============================
1.1 What is siiRL?
------------------
**siiRL** (Shanghai Innovation Institute RL Framework) is a novel **fully distributed reinforcement learning framework** designed to break the scaling barriers in LLM post-training. By eliminating the centralized controller common in other frameworks, siiRL achieves:
- **Near-Linear Scalability**: The multi-controller paradigm eliminates central bottlenecks by distributing control logic and data management across all workers
- **SOTA Throughput**: Fully distributed dataflow architecture minimizes communication and I/O overhead
- **Flexible DAG-Defined Pipeline**: Decouples algorithmic logic from physical hardware, enabling rapid experimentation
1.2 System Architecture and Data Flow
-------------------------------------
**System Architecture Diagram**:
.. figure:: https://github.com/sii-research/siiRL/raw/main/asset/overview.png
:width: 100%
:alt: siiRL Architecture Overview
:align: center
**Figure 1.1**: siiRL System Architecture showing the three core components: DAG Planner, DAG Workers, and Data Coordinator
**Complete Training Step Sequence Diagram**:
The following sequence diagram shows the complete data flow for a single GRPO training step:
::
User MainRunner DAGWorker DataCoordinator Engine
(YAML) (Planner) (per GPU) (Singleton) Workers
| | | | |
============================================================================
| INITIALIZATION PHASE |
============================================================================
| | | | |
| 1. Config | | | |
|-------------->| | | |
| | | | |
| | 2. load_pipeline() + TaskScheduler.schedule() |
| |------------------------------------------------>|
| | | | |
| | 3. Create DAGWorkers (one per GPU) |
| |-------------->| | |
| | | | |
| | | 4. init_graph() | |
| | | Load models | |
| | |-------------------------------->|
| | | | |
============================================================================
| TRAINING LOOP (per step) |
============================================================================
| | | | |
| | | 5. DataLoader | |
| | | .run() | |
| | |<----------------| |
| | | batch (prompts) | |
| | | | |
| | | 6. Node: rollout_actor |
| | |-------------------------------->|
| | | Rollout.generate_sequences()|
| | |<--------------------------------|
| | | batch + responses |
| | | | |
| | | 7. Node: function_reward |
| | | compute_reward() |
| | |---------------->| |
| | | batch + scores | |
| | | | |
| | | 8. Node: calculate_advantages |
| | | compute_advantage() |
| | | (GRPO group normalization) |
| | | | |
| | | 9. put_data_to_buffers() |
| | | (if DP size changes) |
| | |---------------->| |
| | | | ray.put() |
| | | | |
| | | 10. get_data_from_buffers() |
| | |<----------------| |
| | | redistributed batch |
| | | | |
| | | 11. Node: actor_old_log_prob |
| | |-------------------------------->|
| | | Actor.compute_log_prob() |
| | |<--------------------------------|
| | | batch + old_log_probs |
| | | | |
| | | 12. Node: reference_log_prob |
| | |-------------------------------->|
| | | Reference.compute_ref_log_prob|
| | |<--------------------------------|
| | | batch + ref_log_probs |
| | | | |
| | | 13. Node: actor_train |
| | |-------------------------------->|
| | | Actor.update_actor() |
| | | - Forward pass |
| | | - Compute policy loss |
| | | - Backward pass |
| | | - Optimizer step |
| | |<--------------------------------|
| | | metrics |
| | | | |
| | | 14. sync_weights_actor_to_rollout
| | |-------------------------------->|
| | | ShardingManager.sync() |
| | | | |
| | | 15. Log metrics + checkpoint |
| | | | |
============================================================================
| REPEAT for next training step |
============================================================================
**Data Flow Summary**:
::
GRPO Single Step Data Flow
==============================================================================
DataLoader
|
| batch: {prompts, attention_mask, index}
v
+---------------------+
| rollout_actor | DAGWorker.generate()
| (MODEL_INFERENCE) | -> Rollout.generate_sequences()
+----------+----------+
| + {responses, response_ids, response_mask}
v
+---------------------+
| function_reward | DAGWorker.compute_reward()
| (COMPUTE) | -> RewardManager.compute_reward()
+----------+----------+
| + {token_level_scores, token_level_rewards}
v
+---------------------+
| calculate_advantages| DAGWorker.compute_advantage()
| (COMPUTE) | -> compute_grpo_outcome_advantage()
+----------+----------+ Group by prompt -> Normalize (score - mean)/std
| + {advantages}
v
+---------------------+
| actor_old_log_prob | DAGWorker.compute_old_log_prob()
| (MODEL_TRAIN) | -> Actor.compute_log_prob()
| only_forward=True |
+----------+----------+
| + {old_log_probs}
v
+---------------------+
| reference_log_prob | DAGWorker.compute_ref_log_prob()
| (MODEL_TRAIN) | -> Reference.compute_ref_log_prob()
+----------+----------+
| + {ref_log_prob}
v
+---------------------+
| actor_train | DAGWorker.train_actor()
| (MODEL_TRAIN) | -> Actor.update_actor()
+----------+----------+ policy_loss = -advantages * clip(ratio)
|
| metrics: {loss, clipfrac, kl, lr, ...}
v
+---------------------+
| sync_weights | ShardingManager.sync_weights_actor_to_rollout()
+---------------------+
1.3 Core Component Responsibilities
-----------------------------------
.. list-table:: siiRL Core Components
:header-rows: 1
:widths: 20 20 60
* - Component
- Process/Actor
- Core Responsibilities
* - **DAG Planner**
- MainRunner Actor
- Parse user-defined DAG workflows, generate execution plans, assign tasks to workers
* - **DAG Worker**
- One Actor per GPU
- Core execution unit responsible for model initialization, task execution, data flow management
* - **Data Coordinator**
- Global Singleton Actor
- Manage distributed data lifecycle including data loading and intermediate data redistribution
* - **TaskScheduler**
- Inside MainRunner
- Split and assign TaskGraph to each DAG Worker
* - **ProcessGroupManager**
- Inside MainRunner
- Manage creation and configuration of distributed communication groups (TP/PP/DP)
* - **MetricWorker**
- Standalone Actor
- Distributed metrics collection and aggregation
1.4 Why is siiRL Different?
---------------------------
**Problems with Traditional Frameworks**:
1. **Single-Controller Bottleneck**: All data flows through a single node, causing I/O and communication overhead
2. **Rigid Algorithm Pipelines**: Modifying workflows requires deep modifications to framework source code
**siiRL's Solutions**:
.. list-table:: siiRL Design Advantages
:header-rows: 1
:widths: 25 35 40
* - Traditional Frameworks
- siiRL DistFlow
- Advantage
* - Centralized Controller
- Multi-Controller Paradigm
- Eliminates single-point bottleneck, near-linear scaling
* - Hard-coded Workflows
- DAG-Defined Pipeline
- Declarative configuration, no code modification needed
* - Centralized Data Management
- Distributed Data Coordinator
- Avoids OOM, parallelizes data loading
----
.. _sec2_design_philosophy:
2. DistFlow Design Philosophy
=============================
2.1 Fully Distributed Architecture
----------------------------------
The core idea of DistFlow is **"no central coordinator"**. Each DAG Worker is an independent execution unit with its own:
- Data loader (partitioned dataset)
- Model instances (Actor/Critic/Rollout/Reference/Reward)
- Task execution graph (subgraph of TaskGraph)
- Local data cache
2.2 Three-Layer Architecture Design
-----------------------------------
::
┌─────────────────────────────────────────────────────────────────┐
│ User Configuration Layer (YAML/Python) │
│ - workflow_grpo.yaml: Define algorithm DAG │
│ - config.yaml: Model, data, training parameters │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Execution Scheduling Layer (DAG Planner) │
│ - TaskScheduler: Task assignment │
│ - ProcessGroupManager: Communication group management │
│ - GraphUpdater: Configuration injection │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Distributed Execution Layer (DAG Workers) │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │Worker 0 │ │Worker 1 │ │Worker 2 │ │Worker N │ │
│ │ (GPU 0) │ │ (GPU 1) │ │ (GPU 2) │ │ (GPU N) │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Data Coordination Layer (Data Coordinator) │
│ - Distributed DataLoader: Partitioned data loading │
│ - Distributed DataBuffer: Intermediate data redistribution │
└─────────────────────────────────────────────────────────────────┘
2.3 Core Design Principles
--------------------------
.. list-table:: DistFlow Design Principles
:header-rows: 1
:widths: 25 75
* - Principle
- Description
* - **Worker Autonomy**
- Each DAG Worker is a fully independent execution unit, not dependent on central coordination
* - **Data Locality**
- Data is processed locally as much as possible, reducing cross-node transfers
* - **Declarative Workflows**
- Algorithm logic is declared via DAG, decoupled from execution engine
* - **Unified Sample Protocol**
- All intermediate data uses Sample/SampleInfo protocol, supporting flexible routing
* - **Late Binding**
- Configuration is injected into nodes at runtime, supporting dynamic adjustment
----
.. _sec3_main_entry:
3. Program Entry and Startup Flow
=================================
3.1 main_dag.py Explained
-------------------------
``main_dag.py`` is the entry point of siiRL, but unlike traditional frameworks, its role is a **launcher** rather than an executor.
.. code-block:: python
:caption: siirl/main_dag.py Core Structure
def main() -> None:
"""Main entry: Initialize Ray cluster, parse config, start MainRunner"""
# 1. Initialize Ray cluster
if not ray.is_initialized():
ray.init(runtime_env={"env_vars": RAY_RUNTIME_ENV_VARS})
# 2. Parse configuration
siirl_args = parse_config()
# 3. Start main orchestration Actor
runner = MainRunner.remote()
ray.get(runner.run.remote(siirl_args))
3.2 MainRunner Actor
--------------------
``MainRunner`` is the "brain" of the system, responsible for orchestrating the entire training workflow:
.. code-block:: python
:caption: MainRunner.run() Core Flow
@ray.remote(num_cpus=MAIN_RUNNER_CPU_RESERVATION)
class MainRunner:
def run(self, siirl_args: SiiRLArguments) -> None:
# 1. Initialize DataCoordinator
data_coordinator_handle = init_data_coordinator(
num_buffers=siirl_args.trainer.nnodes,
ppo_mini_batch_size=siirl_args.actor_rollout_ref.actor.ppo_mini_batch_size,
world_size=siirl_args.trainer.nnodes * siirl_args.trainer.n_gpus_per_node
)
# 2. Load and configure workflow DAG
workflow_taskgraph = load_pipeline(siirl_args)
update_task_graph_node_configs(workflow_taskgraph, siirl_args)
# 3. Schedule tasks to each worker
task_scheduler = TaskScheduler(siirl_args.trainer.nnodes,
siirl_args.trainer.n_gpus_per_node)
rank_taskgraph_mapping = task_scheduler.schedule_and_assign_tasks([workflow_taskgraph])
# 4. Create process groups
process_group_manager = ProcessGroupManager(total_workers, rank_taskgraph_mapping)
# 5. Create metric worker
metric_worker_handle = MetricWorker.remote()
# 6. Initialize and start DAG Workers
trainer = RayTrainer(config=siirl_args, ...)
trainer.init_workers()
trainer.start_workers()
3.3 Startup Flow Sequence Diagram
---------------------------------
::
main()
│
├── ray.init() ← Initialize Ray cluster
│
├── parse_config() ← Parse YAML configuration
│
└── MainRunner.run()
│
├── init_data_coordinator() ← Create global DataCoordinator
│
├── load_pipeline() ← Load DAG definition
│ │
│ └── grpo_pipeline() ← Return TaskGraph
│
├── TaskScheduler.schedule() ← Assign tasks to each rank
│
├── ProcessGroupManager() ← Create communication group specs
│
├── RayTrainer.init_workers() ← Create DAG Worker Actors
│ │
│ └── DAGWorker.__init__() × N_workers
│
└── RayTrainer.start_workers() ← Start training loop
│
└── DAGWorker.execute_task_graph() × N_workers
----
.. _sec4_dag_planner:
4. DAG Planner Deep Dive
========================
The DAG Planner is siiRL's "scheduling brain", responsible for converting user-defined high-level workflows into executable distributed tasks.
**Pipeline Architecture Overview**:
The following diagram shows how the core data structures relate to each other and how a Pipeline is built and executed:
::
Pipeline Data Structure Relationships
==============================================================================
+------------------+
| Pipeline |
| (Builder) |
+------------------+
| - pipeline_id |
| - description |
| - _nodes: Dict |
+--------+---------+
|
| .build()
v
+------------------+
| TaskGraph |
| (DAG) |
+------------------+
| - graph_id |
| - nodes: Dict |
| - adj: Dict |
| - rev_adj: Dict |
+--------+---------+
|
| contains multiple
v
+----------------+ +----------------+ +----------------+
| Node | | Node | | Node | ...
+----------------+ +----------------+ +----------------+
| - node_id | | - node_id | | - node_id |
| - node_type | | - node_type | | - node_type |
| - node_role | | - node_role | | - node_role |
| - dependencies | | - dependencies | | - dependencies |
| - executable | | - executable | | - executable |
| - config | | - config | | - config |
+----------------+ +----------------+ +----------------+
==============================================================================
NodeType (from node.py) NodeRole (from node.py)
+------------------------+ +------------------------+
| COMPUTE | | DEFAULT |
| DATA_LOAD | | ACTOR |
| ENV_INTERACT | | ADVANTAGE |
| MODEL_INFERENCE | | CRITIC |
| MODEL_TRAIN | | ROLLOUT |
| PUT_TO_BUFFER | | REFERENCE |
| GET_FROM_BUFFER | | REWARD |
| BARRIER_SYNC | | DYNAMIC_SAMPLING |
| CUSTOM | +------------------------+
+------------------------+
**Pipeline Building Flow**:
::
How Pipeline is Built and Executed
================================================================================
Step 1: User Defines Pipeline (Python Code)
--------------------------------------------
pipeline = Pipeline("grpo_training_pipeline")
pipeline.add_node("rollout_actor", func="...:DAGWorker.generate", deps=[])
.add_node("function_reward", func="...:DAGWorker.compute_reward", ...)
.add_node("calculate_advantages", func="...:DAGWorker.compute_advantage", ...)
.add_node("actor_old_log_prob", func="...:DAGWorker.compute_old_log_prob", ...)
.add_node("reference_log_prob", func="...:DAGWorker.compute_ref_log_prob", ...)
.add_node("actor_train", func="...:DAGWorker.train_actor", ...)
|
| pipeline.build()
v
Step 2: Build TaskGraph (Validation + Adjacency Lists)
------------------------------------------------------
TaskGraph Adjacency Lists (adj)
+--------------------+ +------------------------------------------+
| graph_id: "grpo.." | | rollout_actor -> [function_reward] |
| | | function_reward -> [calculate_adv.] |
| nodes: { | | calculate_adv. -> [actor_old_log] |
| "rollout_actor", | | actor_old_log -> [reference_log] |
| "function_reward"| | reference_log -> [actor_train] |
| "calculate_adv.",| | actor_train -> [] |
| ... | +------------------------------------------+
| } |
+--------------------+
|
| TaskScheduler.schedule()
v
Step 3: TaskScheduler Assigns to Workers
----------------------------------------
+------------------------------------------------------------------------+
| TaskScheduler |
| |
| Input: TaskGraph + num_workers |
| |
| 1. discover_and_split_parallel_paths(graph) -> Split parallel branches|
| 2. Apportion workers to subgraphs (param_aware / even) |
| 3. Assign each worker a TaskGraph copy |
| |
| Output: Dict[rank, TaskGraph] (rank_taskgraph_mapping) |
+------------------------------------------------------------------------+
+-------------------------------------------+
| rank_taskgraph_mapping |
+-------------------------------------------+
| rank 0 -> TaskGraph (copy) |
| rank 1 -> TaskGraph (copy) |
| rank 2 -> TaskGraph (copy) |
| ... -> ... |
| rank N -> TaskGraph (copy) |
+-------------------------------------------+
|
| DAGWorker receives TaskGraph
v
Step 4: DAGWorker Executes TaskGraph
------------------------------------
+------------------------------------------------------------------------+
| DAGWorker.execute_task_graph() |
| |
| for each training step: |
| 1. batch = DataLoader.run() |
| 2. entry_nodes = taskgraph.get_entry_nodes() # [rollout_actor] |
| 3. node_queue = entry_nodes |
| |
| while node_queue: |
| cur_node = node_queue.pop(0) |
| |
| # Execute node's function |
| output = cur_node.run(batch=batch, _dag_worker_instance=self) |
| |
| # Resolves executable_ref to actual function: |
| # "siirl.dag_worker.dagworker:DAGWorker.generate" |
| # -> DAGWorker.generate(self, batch, ...) |
| |
| # Get downstream nodes and add to queue |
| next_nodes = taskgraph.get_downstream_nodes(cur_node.node_id) |
| node_queue.extend(next_nodes) |
| |
| # If DP size changes between nodes, use DataCoordinator |
| put_data_to_buffers() / get_data_from_buffers() |
+------------------------------------------------------------------------+
**Execution Order Example (GRPO)**:
::
GRPO Pipeline Execution Order
================================================================================
Topological Order:
+------------------+ +------------------+ +---------------------+
| rollout_actor |----->| function_reward |----->|calculate_advantages |
| (Inference) | | (Compute) | | (Compute) |
| | | | | |
| NodeRole: | | NodeRole: | | NodeRole: |
| ROLLOUT | | REWARD | | ADVANTAGE |
+------------------+ +------------------+ +----------+----------+
|
+----------------------------------------------------------+
|
v
+---------------------+ +---------------------+ +------------------+
| actor_old_log_prob |----->| reference_log_prob |----->| actor_train |
| (Forward Only) | | (Forward Only) | | (Train) |
| | | | | |
| NodeRole: ACTOR | | NodeRole: REFERENCE| | NodeRole: ACTOR |
| only_forward=True | | | | |
+---------------------+ +---------------------+ +------------------+
Data flows through each node, accumulating fields in the batch:
batch: {prompts}
|
v rollout_actor
batch: {prompts, responses, response_ids, response_mask}
|
v function_reward
batch: {..., token_level_scores, token_level_rewards}
|
v calculate_advantages
batch: {..., advantages}
|
v actor_old_log_prob
batch: {..., old_log_probs}
|
v reference_log_prob
batch: {..., ref_log_prob}
|
v actor_train
metrics: {loss, clipfrac, kl, ...}
4.1 Pipeline API
----------------
siiRL provides a clean Pipeline API for users to define training pipelines directly in Python:
.. code-block:: python
:caption: siirl/execution/dag/pipeline.py
class Pipeline:
"""Declarative Pipeline Builder"""
    def __init__(self, pipeline_id: str, description: str = ""):
        self.pipeline_id = pipeline_id
        self.description = description
        self._nodes: Dict[str, Dict[str, Any]] = {}
def add_node(
self,
node_id: str,
func: Union[str, Callable], # Function path or direct Callable
deps: Optional[List[str]] = None,
**kwargs
) -> "Pipeline":
"""Add node with method chaining support"""
self._nodes[node_id] = {
"func": func,
"deps": deps or [],
"kwargs": kwargs
}
return self # Support method chaining
def build(self) -> TaskGraph:
"""Build and validate TaskGraph"""
task_graph = TaskGraph(graph_id=self.pipeline_id)
# ... create nodes, build adjacency lists, validate DAG
return task_graph
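Because ``add_node`` returns the builder itself, pipelines are normally composed through method chaining. The following minimal sketch illustrates the pattern; the toy ``generate``/``train`` callables are placeholders standing in for real DAGWorker methods (``func`` accepts either a direct ``Callable`` or a string reference, per the signature above):

.. code-block:: python

    from siirl.execution.dag.pipeline import Pipeline

    def generate(batch=None, **kwargs):
        """Toy stand-in for a rollout step."""
        return batch

    def train(batch=None, **kwargs):
        """Toy stand-in for a training step."""
        return batch

    # add_node() returns the Pipeline itself, so calls chain;
    # build() validates the graph and returns a TaskGraph.
    graph = (
        Pipeline("toy_pipeline", "Minimal two-node example")
        .add_node("generate", func=generate, deps=[])
        .add_node("train", func=train, deps=["generate"])
        .build()
    )

    print(graph.graph_id)       # "toy_pipeline"
    print(sorted(graph.nodes))  # ['generate', 'train']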
4.2 Built-in Pipeline Definitions
---------------------------------
siiRL provides four built-in pipeline definitions in ``siirl/execution/dag/builtin_pipelines.py``:
**4.2.1 GRPO Pipeline (grpo_pipeline)**
Standard GRPO (Group Relative Policy Optimization) training workflow:
.. code-block:: python
:caption: siirl/execution/dag/builtin_pipelines.py - GRPO Pipeline
def grpo_pipeline() -> TaskGraph:
"""
Standard GRPO (Group Relative Policy Optimization) pipeline.
Workflow:
1. rollout_actor: Generate sequences using the policy model
2. function_reward: Compute rewards for generated sequences
3. calculate_advantages: Calculate advantage estimates
4. actor_old_log_prob: Compute log probabilities with old policy (forward only)
5. reference_log_prob: Compute log probabilities with reference model
6. actor_train: Train the actor model
"""
pipeline = Pipeline("grpo_training_pipeline", "Standard GRPO workflow")
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"function_reward",
func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["function_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR,
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.REFERENCE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
return pipeline.build()
**4.2.2 PPO Pipeline (ppo_pipeline)**
Standard PPO with Critic model and GAE advantage estimation:
.. code-block:: python
:caption: siirl/execution/dag/builtin_pipelines.py - PPO Pipeline
def ppo_pipeline() -> TaskGraph:
"""
Standard PPO (Proximal Policy Optimization) pipeline.
Workflow:
1. rollout_actor: Generate sequences using the policy model
2. function_reward: Compute rewards for generated sequences
3. compute_value: Compute value function estimates (forward only)
4. calculate_advantages: Calculate GAE (Generalized Advantage Estimation)
5. actor_old_log_prob: Compute log probabilities with old policy (forward only)
6. reference_log_prob: Compute log probabilities with reference model
7. actor_train: Train the actor model
8. critic_train: Train the critic (value) model
"""
pipeline = Pipeline("ppo_training_pipeline", "Standard PPO workflow")
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"function_reward",
func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"compute_value",
func="siirl.dag_worker.dagworker:DAGWorker.compute_value",
deps=["function_reward"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.CRITIC,
only_forward_compute=True
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["compute_value"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR,
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.REFERENCE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
).add_node(
"critic_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_critic",
deps=["actor_train"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.CRITIC
)
return pipeline.build()
**4.2.3 DAPO Pipeline (dapo_pipeline)**
DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization) with dynamic-sampling filtering:
.. code-block:: python
:caption: siirl/execution/dag/builtin_pipelines.py - DAPO Pipeline
def dapo_pipeline() -> TaskGraph:
"""
    DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization) pipeline.
    DAPO is a GRPO variant that adds dynamic-sampling filtering based on metric variance.
    The key difference is that, after computing rewards, trajectory groups with zero
    variance (all correct or all incorrect) are filtered out, because they provide no
    learning signal.
Workflow:
1. rollout_actor: Generate sequences using the policy model
2. function_reward: Compute rewards for generated sequences
3. dynamic_sampling: DAPO-specific filtering based on metric variance
4. calculate_advantages: Calculate advantage estimates
5. actor_old_log_prob: Compute log probabilities with old policy (forward only)
6. reference_log_prob: Compute log probabilities with reference model
7. actor_train: Train the actor model
"""
pipeline = Pipeline("dapo_training_pipeline", "DAPO workflow")
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"function_reward",
func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"dynamic_sampling",
func="siirl.user_interface.filter_interface.dapo.dynamic_sampling",
deps=["function_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["dynamic_sampling"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR,
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.REFERENCE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
return pipeline.build()
**4.2.4 Embodied SRPO Pipeline (embodied_srpo_pipeline)**
Embodied AI SRPO training with data filtering and VJEPA-based reward computation:
.. code-block:: python
:caption: siirl/execution/dag/builtin_pipelines.py - Embodied SRPO Pipeline
def embodied_srpo_pipeline() -> TaskGraph:
"""
Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.
Workflow:
1. rollout_actor: Environment rollout with embodied AI agent
    2. dynamic_sampling: Data verification, filtering, and rebalancing across workers
    3. compute_reward: VJEPA-based reward computation
    4. calculate_advantages: Calculate advantages (GRPO group-based)
    5. actor_old_log_prob: Compute old actor log probabilities (forward only)
    6. reference_log_prob: Compute reference model log probabilities
    7. actor_train: Actor training with GRPO
"""
pipeline = Pipeline(
"embodied_grpo_training_pipeline",
"Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation."
)
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"dynaminc_sampling",
func="siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
).add_node(
"compute_reward",
func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
deps=["dynaminc_sampling"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["compute_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR,
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.REFERENCE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
return pipeline.build()
**Pipeline Comparison Table**:
.. list-table:: Built-in Pipeline Comparison
:header-rows: 1
:widths: 15 45 40
* - Pipeline
- Key Difference
- Use Case
* - **GRPO**
- Group-based advantage normalization
- Reasoning tasks, math problems
* - **PPO**
- Critic model + GAE advantage estimation
- General RL tasks with value function
* - **DAPO**
- Dynamic sampling to filter zero-variance groups
- Challenging tasks with sparse rewards
* - **Embodied SRPO**
- Environment interaction + VJEPA reward + dynamic sampling
- Robotics, embodied AI tasks
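Since each built-in definition is an ordinary function returning a validated ``TaskGraph``, the resulting graph can be inspected directly. A short sketch, assuming the module path given above:

.. code-block:: python

    from siirl.execution.dag.builtin_pipelines import grpo_pipeline

    graph = grpo_pipeline()

    # Topological sort yields the linear execution order of the GRPO DAG.
    print(graph.get_topological_sort())
    # ['rollout_actor', 'function_reward', 'calculate_advantages',
    #  'actor_old_log_prob', 'reference_log_prob', 'actor_train']

    # Entry nodes are those with no dependencies.
    print([n.node_id for n in graph.get_entry_nodes()])  # ['rollout_actor']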
4.3 Node Data Structure
-----------------------
Each DAG node is represented by the ``Node`` class:
.. code-block:: python
:caption: siirl/execution/dag/node.py
class NodeType(Enum):
"""Define the types of nodes in the DAG."""
COMPUTE = "COMPUTE" # General computing task
DATA_LOAD = "DATA_LOAD" # Load data from DataLoader
ENV_INTERACT = "ENV_INTERACT" # Interact with the environment
MODEL_INFERENCE = "MODEL_INFERENCE" # Model inference (Rollout)
MODEL_TRAIN = "MODEL_TRAIN" # Model training
PUT_TO_BUFFER = "PUT_TO_BUFFER" # Put data into the distributed buffer
GET_FROM_BUFFER = "GET_FROM_BUFFER" # Get data from the distributed buffer
BARRIER_SYNC = "BARRIER_SYNC" # Global synchronization point
CUSTOM = "CUSTOM" # User-defined node type
class NodeRole(Enum):
"""Define the roles that a node plays in a distributed RL framework."""
DEFAULT = "DEFAULT" # Default role
ACTOR = "ACTOR" # Actor model (policy)
ADVANTAGE = "ADVANTAGE" # Advantage computation
CRITIC = "CRITIC" # Critic model (value function)
ROLLOUT = "ROLLOUT" # Rollout inference engine
REFERENCE = "REFERENCE" # Reference model (for KL)
REWARD = "REWARD" # Reward computation
DYNAMIC_SAMPLING = "DYNAMIC_SAMPLING" # Dynamic sampling in databuffer (DAPO/Embodied)
class NodeStatus(Enum):
"""Define the execution status of a DAG node."""
PENDING = "PENDING" # Waiting for dependencies to complete
READY = "READY" # Dependencies completed, ready to execute
RUNNING = "RUNNING" # Currently executing
COMPLETED = "COMPLETED" # Execution completed successfully
FAILED = "FAILED" # Execution failed
SKIPPED = "SKIPPED" # Skipped
class Node:
"""Represents a node (task unit) in the DAG."""
def __init__(
self,
node_id: str,
node_type: NodeType,
node_role: NodeRole = NodeRole.DEFAULT,
only_forward_compute: bool = False, # Forward only, no weight update
agent_group: int = 0, # Multi-agent scenario grouping
dependencies: Optional[List[str]] = None,
config: Optional[Dict[str, Any]] = None,
executable_ref: Optional[str] = None, # Function path "module:Class.method"
filter_plugin: Optional[Callable] = None, # Filter function for data
agent_options: AgentArguments = None,
retry_limit: int = 0,
):
self.node_id = node_id
self.node_type = node_type
self.node_role = node_role
self.only_forward_compute = only_forward_compute
self.agent_group = agent_group
self.dependencies = dependencies or []
self.config = config or {}
        self.executable_ref = executable_ref
        self.filter_plugin = filter_plugin
        self.agent_options = agent_options
        self.retry_limit = retry_limit
self._executable: Optional[Callable] = None
self.status = NodeStatus.PENDING
# Resolve executable function from path
if self.executable_ref:
self._resolve_executable()
def _resolve_executable(self) -> None:
"""Dynamically import and obtain the executable function.
Supports two formats:
1. "module.path:ClassName.method" - imports module.path, gets ClassName.method
2. "module.path.function" - imports module.path, gets function
"""
if ":" in self.executable_ref:
module_path, attr_path = self.executable_ref.split(":", 1)
module = importlib.import_module(module_path)
obj = module
for attr_name in attr_path.split("."):
obj = getattr(obj, attr_name)
self._executable = obj
else:
module_path, function_name = self.executable_ref.rsplit(".", 1)
module = importlib.import_module(module_path)
self._executable = getattr(module, function_name)
def run(self, **kwargs) -> Any:
"""Execute the task of the node."""
        if self._executable:
            return self._executable(**kwargs)
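The two reference formats accepted by ``_resolve_executable`` can be exercised with the standard library alone. The sketch below mirrors the resolution logic without depending on siiRL (``json:JSONEncoder.encode`` and ``operator.mul`` are arbitrary stand-ins):

.. code-block:: python

    import importlib

    def resolve(ref: str):
        """Resolve "module.path:Class.method" or "module.path.function" to a callable."""
        if ":" in ref:
            module_path, attr_path = ref.split(":", 1)
            obj = importlib.import_module(module_path)
            for attr_name in attr_path.split("."):
                obj = getattr(obj, attr_name)
            return obj
        module_path, function_name = ref.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), function_name)

    # Format 1: "module:Class.method" -> an unbound method
    encode = resolve("json:JSONEncoder.encode")
    # Format 2: "module.function" -> a plain function
    mul = resolve("operator.mul")
    print(mul(6, 7))  # 42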
4.4 TaskGraph Data Structure
----------------------------
``TaskGraph`` represents the entire training workflow as a DAG:
.. code-block:: python
:caption: siirl/execution/dag/task_graph.py
class TaskGraph:
"""Directed Acyclic Graph representing training workflow"""
def __init__(self, graph_id: str):
self.graph_id = graph_id
self.nodes: Dict[str, Node] = {} # node_id -> Node
self.adj: Dict[str, List[str]] = {} # Forward adjacency: node -> dependents
self.rev_adj: Dict[str, List[str]] = {} # Reverse adjacency: node -> dependencies
def add_node(self, node: Node) -> None:
"""Add node to graph"""
self.nodes[node.node_id] = node
self._update_adj_for_node(node)
def get_topological_sort(self) -> List[str]:
"""Topological sort using Kahn's algorithm"""
# ... implement Kahn's algorithm
def validate_graph(self) -> Tuple[bool, Optional[str]]:
"""Validate DAG validity (no cycles, dependencies exist)"""
# 1. Check all dependencies exist
# 2. Use topological sort to detect cycles
try:
self.get_topological_sort()
return True, None
except ValueError as e:
return False, str(e)
def get_entry_nodes(self) -> List[Node]:
"""Get entry nodes (no dependencies)"""
return [node for node_id, node in self.nodes.items()
if not self.rev_adj.get(node_id)]
def get_downstream_nodes(self, node_id: str) -> List[Node]:
"""Get downstream nodes"""
return self.get_dependents(node_id)
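The elided ``get_topological_sort`` boils down to Kahn's algorithm over the forward adjacency lists. A self-contained sketch of that algorithm, operating on a plain dict shaped like ``adj``:

.. code-block:: python

    from collections import deque
    from typing import Dict, List

    def kahn_topological_sort(adj: Dict[str, List[str]]) -> List[str]:
        """Return a topological order of `adj`, raising ValueError on a cycle."""
        in_degree = {node: 0 for node in adj}
        for downstream in adj.values():
            for dst in downstream:
                in_degree[dst] += 1

        queue = deque(node for node, deg in in_degree.items() if deg == 0)
        order: List[str] = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for dst in adj[node]:
                in_degree[dst] -= 1
                if in_degree[dst] == 0:
                    queue.append(dst)

        if len(order) != len(adj):
            raise ValueError("Graph contains a cycle")  # what validate_graph() catches
        return order

    # The GRPO adjacency list from Step 2 sorts into its execution order:
    adj = {
        "rollout_actor": ["function_reward"],
        "function_reward": ["calculate_advantages"],
        "calculate_advantages": ["actor_old_log_prob"],
        "actor_old_log_prob": ["reference_log_prob"],
        "reference_log_prob": ["actor_train"],
        "actor_train": [],
    }
    print(kahn_topological_sort(adj))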
4.5 TaskScheduler
-----------------
``TaskScheduler`` is responsible for assigning TaskGraph to each worker:
.. code-block:: python
:caption: siirl/execution/scheduler/task_scheduler.py
class TaskScheduler:
"""Task Scheduler: Assign TaskGraph to Workers"""
def __init__(self, num_physical_nodes: int, gpus_per_node: int):
self.num_physical_nodes = num_physical_nodes
self.gpus_per_node = gpus_per_node
self.num_workers = num_physical_nodes * gpus_per_node
# State variables
self.worker_to_graph_assignment: Dict[int, Optional[TaskGraph]] = {}
self.node_active_worker_count: Dict[int, int] = defaultdict(int)
self.node_free_gpus: Dict[int, List[int]] = defaultdict(list)
def schedule_and_assign_tasks(
self,
original_task_graphs: List[TaskGraph],
apportion_strategy: str = "param_aware", # or "even"
consider_node_cohesion: bool = True, # Same task on same physical node
consider_node_load: bool = True, # Prefer lower load nodes
) -> Dict[int, Optional[TaskGraph]]:
"""Schedule tasks to each worker"""
# 1. Split original graphs into irreducible subgraphs
all_subgraphs = []
for graph in original_task_graphs:
subgraphs = discover_and_split_parallel_paths(graph)
all_subgraphs.extend(subgraphs)
# 2. Estimate subgraph sizes and sort
subgraphs_with_sizes = sorted(
[(sg, estimate_graph_model_params(sg)) for sg in all_subgraphs],
key=lambda x: x[1],
reverse=True
)
# 3. Apportion worker counts
workers_per_task = self._apportion_workers_to_tasks(
subgraphs_with_sizes,
self.num_workers,
apportion_strategy
)
# 4. Place workers (considering cohesion and load balancing)
for task_graph, _ in subgraphs_with_sizes:
num_workers = workers_per_task[task_graph.graph_id]
for _ in range(num_workers):
best_worker = self._find_best_worker(
task_graph, consider_node_cohesion, consider_node_load
)
self.worker_to_graph_assignment[best_worker] = task_graph
return self.worker_to_graph_assignment
**Scheduling Strategy Comparison**:
.. list-table:: Scheduling Strategies
:header-rows: 1
:widths: 20 40 40
* - Strategy
- Description
- Use Case
* - **even**
- Distribute workers evenly among tasks
- Similar task workloads
* - **param_aware**
- Distribute based on model parameter ratio
- Large variance in task sizes
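The ``param_aware`` strategy amounts to proportional apportionment: each subgraph receives workers in proportion to its estimated parameter count, with fractional remainders resolved by a largest-remainder rule. A minimal sketch of that idea (not the exact ``_apportion_workers_to_tasks`` implementation):

.. code-block:: python

    from typing import Dict

    def apportion_param_aware(sizes: Dict[str, int], num_workers: int) -> Dict[str, int]:
        """Assign workers proportionally to parameter counts, at least one per task."""
        total = sum(sizes.values())
        quotas = {gid: num_workers * size / total for gid, size in sizes.items()}
        counts = {gid: max(1, int(q)) for gid, q in quotas.items()}
        remaining = num_workers - sum(counts.values())
        # Hand leftover workers to the tasks with the largest fractional remainder.
        for gid in sorted(quotas, key=lambda g: quotas[g] - int(quotas[g]), reverse=True):
            if remaining <= 0:
                break
            counts[gid] += 1
            remaining -= 1
        return counts

    # A 7B training subgraph vs. a 1B reward subgraph sharing 8 workers:
    print(apportion_param_aware(
        {"train_graph": 7_000_000_000, "reward_graph": 1_000_000_000}, 8))
    # {'train_graph': 7, 'reward_graph': 1}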
4.6 Task Graph Splitting (task_loader.py)
-----------------------------------------
The ``task_loader.py`` module provides utilities for analyzing and splitting complex TaskGraphs:
.. code-block:: python
:caption: siirl/execution/dag/task_loader.py
def discover_and_split_parallel_paths(src_task_graph: TaskGraph) -> List[TaskGraph]:
"""
Discovers and splits a TaskGraph into irreducible subgraphs by iteratively
identifying and splitting fan-out and re-converging parallel paths.
Args:
src_task_graph: The original TaskGraph to be analyzed and split
Returns:
List[TaskGraph]: A list of irreducible subgraph TaskGraph objects
"""
# 1. Try to split by fan-out to distinct exits
graphs_after_fan_out = split_by_fan_out_to_exits(current_graph, iteration_counter)
# 2. If no fan-out split, try to split by re-converging paths
graphs_after_reconverge = split_by_reconverging_paths(current_graph, iteration_counter)
# 3. If no split possible, graph is irreducible
return final_irreducible_graphs
This enables automatic parallelization of independent pipeline branches across different worker groups.
----
.. _sec5_dag_worker:
5. DAG Worker Deep Dive
=======================
DAG Worker is the core execution unit of siiRL, with one DAG Worker running per GPU.
5.1 DAGWorker Class Structure
-----------------------------
.. code-block:: python
:caption: siirl/dag_worker/dagworker.py
class DAGWorker(Worker):
"""DAG Execution Unit, one instance per GPU"""
def __init__(
self,
config: SiiRLArguments,
process_group_manager: ProcessGroupManager,
taskgraph_mapping: Dict[int, TaskGraph],
data_coordinator: ray.actor.ActorHandle,
metric_worker: ray.actor.ActorHandle,
):
# Configuration
self.config = config
self.process_group_manager = process_group_manager
self.taskgraph_mapping = taskgraph_mapping
self.data_coordinator = data_coordinator
# State
self.global_steps = 0
self.workers: Dict[str, Any] = {} # Node role -> Worker instance
self.multi_agent_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)
self.process_groups: Dict[str, ProcessGroup] = {}
self.internal_data_cache: Dict[str, Any] = {}
# Initialize
self._initialize_worker()
5.2 Initialization Flow
-----------------------
DAGWorker initialization is divided into two phases:
**Phase 1: _initialize_worker() in __init__**
.. code-block:: python
def _initialize_worker(self):
"""Initialize all Worker components"""
# 1. Validate rank and get assigned TaskGraph
self._rank = get_and_validate_rank()
self.taskgraph = get_taskgraph_for_rank(self._rank, self.taskgraph_mapping)
# 2. Set up distributed environment
self._setup_distributed_environment()
# 3. Initialize Tokenizer
self._setup_tokenizers()
# 4. Initialize DataLoader
self._setup_dataloader()
# 5. Initialize Reward Manager
self._setup_reward_managers()
# 6. Create role -> Worker class mapping
self._setup_role_worker_mapping()
# 7. Instantiate node Workers
self._initialize_node_workers()
**Phase 2: init_graph() method**
.. code-block:: python
def init_graph(self):
"""Load model weights, restore checkpoint"""
# 1. Load model weights to GPU
self._load_model_weights()
# 2. Set up weight sharing (Actor-Rollout)
self._setup_sharding_manager()
# 3. Initialize async rollout (if configured)
self._setup_async_rollout()
# 4. Initialize multi-agent loop (if configured)
self._setup_multi_agent_loop()
# 5. Initialize validator
self._init_validator()
# 6. Initialize checkpoint manager and restore
self._init_checkpoint_manager()
self.global_steps = self.checkpoint_manager.load_checkpoint()
# 7. Global synchronization
dist.barrier(self._gather_group)
5.3 Training Loop
-----------------
.. code-block:: python
:caption: DAGWorker Training Loop Core Logic
def execute_task_graph(self):
"""Main entry: Execute DAG training pipeline"""
# Optional pre-training validation
if self.config.trainer.val_before_train:
self.validator.validate(global_step=self.global_steps)
# Main training loop
self._run_training_loop()
def _run_training_loop(self):
"""Main training loop"""
for epoch in range(self.config.trainer.total_epochs):
for batch_idx in range(self.dataloader.num_train_batches):
# Execute one training step
ordered_metrics = self._run_training_step(epoch, batch_idx)
self.global_steps += 1
            # Save checkpoint (skipped when save_freq <= 0)
            if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
                self.checkpoint_manager.save_checkpoint(self.global_steps)
            # Run validation (skipped when test_freq <= 0)
            if self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0:
                self.validator.validate(global_step=self.global_steps)
# Log metrics
if self._rank == 0 and self.logger:
self.logger.log(data=ordered_metrics, step=self.global_steps)
5.4 Single Training Step Execution
----------------------------------
.. code-block:: python
:caption: _run_training_step() Explained
def _run_training_step(self, epoch: int, batch_idx: int) -> Optional[Dict]:
"""Execute a single training step"""
# 1. Get data from DataLoader
batch = preprocess_dataloader(
self.dataloader.run(epoch=epoch, is_validation_step=False),
self.config.actor_rollout_ref.rollout.n
)
# 2. Get DAG entry nodes
node_queue = self.taskgraph.get_entry_nodes()
entry_node_id = node_queue[0].node_id
visited_nodes = set()
# 3. Graph traversal execution
while node_queue:
cur_node = node_queue.pop(0)
if cur_node.node_id in visited_nodes:
continue
visited_nodes.add(cur_node.node_id)
# 3.1 Get node's DP/TP/PP info
cur_dp_size, cur_dp_rank, cur_tp_rank, cur_tp_size, cur_pp_rank, cur_pp_size = \
self._get_node_dp_info(cur_node)
# 3.2 Non-entry nodes get data from buffer
if cur_node.node_id != entry_node_id:
batch = self.get_data_from_buffers(
key=cur_node.node_id,
cur_dp_size=cur_dp_size,
cur_dp_rank=cur_dp_rank
)
# 3.3 Execute node
if cur_node.executable and batch is not None:
node_output = cur_node.run(
batch=batch,
config=self.config,
process_group=self._get_node_process_group(cur_node),
agent_group=self.multi_agent_group[cur_node.agent_group],
_dag_worker_instance=self
)
else:
node_output = NodeOutput(batch=batch)
# 3.4 Process output, pass to downstream nodes
if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):
next_node = next_nodes[0]
next_dp_size = self._get_node_dp_info(next_node)[0]
# If DP size changes, need DataCoordinator for redistribution
self.put_data_to_buffers(
key=next_node.node_id,
data=node_output.batch,
source_dp_size=cur_dp_size,
dest_dp_size=next_dp_size
)
# Add downstream nodes to queue
for n in next_nodes:
if n.node_id not in visited_nodes:
node_queue.append(n)
# 4. Clean up caches
self._cleanup_step_buffers()
# 5. Collect and return metrics
return self._collect_metrics()
5.5 Node Execution Methods
--------------------------
DAGWorker provides a series of node execution methods, each corresponding to a node role:
.. code-block:: python
:caption: Node Execution Methods
# Rollout: Generate sequences
def generate(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Generate sequences using the Rollout model"""
agent_group = kwargs.pop("agent_group")
is_embodied = self.config.actor_rollout_ref.model.model_type == "embodied"
if is_embodied:
return self.generate_embodied_mode(agent_group, batch, **kwargs)
if self.rollout_mode == 'sync':
gen_output = agent_group[NodeRole.ROLLOUT].generate_sequences(batch)
batch = batch.update(gen_output)
return NodeOutput(batch=batch, metrics=gen_output["metrics"])
else:
return self.generate_async_mode(batch)
# Reward: Compute rewards
def compute_reward(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
    """Compute rewards for generated sequences"""
    metrics = {}
    reward_tensor, extra_infos = compute_reward(batch, self.reward_fn)
    batch["token_level_scores"] = reward_tensor
    if config.algorithm.use_kl_in_reward:
        batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl_in_reward, ...)
        metrics.update(kl_metrics)
    else:
        batch["token_level_rewards"] = batch["token_level_scores"]
    return NodeOutput(batch=batch, metrics=metrics)
# Advantage: Compute advantages
def compute_advantage(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Compute GAE/GRPO/CPGD advantages"""
return NodeOutput(
batch=compute_advantage(
batch,
adv_estimator=config.algorithm.adv_estimator,
gamma=config.algorithm.gamma,
lam=config.algorithm.lam,
norm_adv_by_std_in_grpo=config.algorithm.norm_adv_by_std_in_grpo
)
)
# Actor Forward: Compute old policy log prob
def compute_old_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Compute log probabilities before policy update"""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.ACTOR].compute_log_prob(batch)
return NodeOutput(batch=processed_data, metrics=processed_data.get("metrics", {}))
# Reference: Compute reference model log prob
def compute_ref_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Compute reference model log probabilities"""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.REFERENCE].compute_ref_log_prob(batch)
return NodeOutput(batch=processed_data, metrics=processed_data["metrics"])
# Actor Train: Train Actor model
def train_actor(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Execute Actor model training step"""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.ACTOR].update_actor(batch)
return NodeOutput(batch=processed_data, metrics=processed_data["metrics"])
# Critic Train: Train Critic model (PPO)
def train_critic(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Execute Critic model training step"""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.CRITIC].update_critic(batch)
return NodeOutput(batch=processed_data, metrics=processed_data["metrics"])
----
.. _sec6_data_coordinator:
6. Data Coordinator Deep Dive
=============================
Data Coordinator is the core component of siiRL's fully distributed data management.
6.1 Design Philosophy
---------------------
**Why do we need Data Coordinator?**
In traditional frameworks, all intermediate data (Rollout outputs, Reward results, etc.) must pass through a central controller for redistribution, causing severe I/O bottlenecks. siiRL's Data Coordinator adopts a different design:
1. **Store only metadata and references**: Actual data is stored in Ray Object Store
2. **Support flexible sampling strategies**: Custom sampling via filter_plugin
3. **Automatic load balancing**: Optimize sequence length distribution via balance_partitions
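Point 1 is the crux: workers exchange only small metadata/reference pairs, while the tensors themselves stay in Ray's object store until a consumer dereferences them. A minimal illustration of the pattern with plain Ray (no siiRL types):

.. code-block:: python

    import ray

    ray.init(ignore_reinit_error=True)

    # The payload is written to the Ray Object Store once...
    payload = {"response_ids": list(range(1024))}
    ref = ray.put(payload)

    # ...while only lightweight metadata plus the ObjectRef travels to the
    # coordinator; dereferencing happens lazily on the consuming worker.
    metadata = {"sum_tokens": 1024, "key": "actor_train"}
    queued = (metadata, ref)

    sample = ray.get(queued[1])
    print(len(sample["response_ids"]))  # 1024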
6.2 DataCoordinator Implementation
----------------------------------
.. code-block:: python
:caption: siirl/data_coordinator/data_buffer.py
@ray.remote
class DataCoordinator:
"""Global singleton data coordination Actor"""
def __init__(self, nnodes: int, ppo_mini_batch_size: int, world_size: int):
self.nnodes = nnodes
self.ppo_mini_batch_size = ppo_mini_batch_size
self.world_size = world_size
# Efficiently store metadata and references using deque
self._sample_queue: deque[Tuple[SampleInfo, ray.ObjectRef]] = deque()
self.lock = asyncio.Lock()
self._cache = []
async def put_batch(
self,
sample_infos: List[SampleInfo],
sample_refs: List[ray.ObjectRef],
caller_node_id: Optional[str] = None
):
"""Register a batch of sample references and metadata"""
# Inject caller node ID (for subsequent routing)
if caller_node_id is None:
caller_node_id = ray.get_runtime_context().get_node_id()
for i in range(len(sample_infos)):
if sample_infos[i].node_id is None:
sample_infos[i].node_id = caller_node_id
async with self.lock:
self._sample_queue.extend(zip(sample_infos, sample_refs))
async def get_batch(
self,
batch_size: int,
dp_rank: int,
filter_plugin: Optional[Callable[[SampleInfo], bool]] = None,
balance_partitions: Optional[int] = None
) -> List[ray.ObjectRef]:
"""Get a batch of sample ObjectRefs"""
async with self.lock:
# 1. If cached, return directly
if len(self._cache) > 0:
return self._cache[dp_rank]
# 2. No filter, use efficient FIFO
if not filter_plugin:
batch_items = []
while self._sample_queue:
item = self._sample_queue.popleft()
batch_items.append(item)
# Apply length balancing
if balance_partitions and balance_partitions > 1:
batch_refs = self._apply_length_balancing(batch_items, balance_partitions)
else:
batch_refs = [item[1] for item in batch_items]
self._cache = batch_refs
return self._cache[:batch_size]
# 3. With filter, execute filtering
else:
potential_items = [item for item in self._sample_queue
if filter_plugin(item[0])]
global_batch_size = batch_size * balance_partitions
if len(potential_items) < global_batch_size:
return []
potential_items = potential_items[:global_batch_size]
# Remove selected items from queue
refs_to_remove = {item[1] for item in potential_items}
self._sample_queue = deque(
item for item in self._sample_queue if item[1] not in refs_to_remove
)
# Apply length balancing and cache
if balance_partitions and balance_partitions > 1:
batch_refs = self._apply_length_balancing(potential_items, balance_partitions)
else:
batch_refs = [item[1] for item in potential_items]
for rank in range(balance_partitions):
self._cache.append(batch_refs[rank * batch_size: (rank + 1) * batch_size])
return self._cache[dp_rank]
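The elided ``_apply_length_balancing`` aims to equalize total token counts across DP partitions. A common way to achieve this is greedy longest-processing-time (LPT) partitioning over ``sum_tokens``; the sketch below shows the idea on plain integers (the real method would additionally reorder the ObjectRefs accordingly):

.. code-block:: python

    import heapq
    from typing import List

    def balance_by_tokens(sum_tokens: List[int], partitions: int) -> List[List[int]]:
        """Greedy LPT: place each sample (longest first) into the currently
        lightest partition; returns sample indices per partition."""
        heap = [(0, p) for p in range(partitions)]  # (current_total, partition)
        heapq.heapify(heap)
        buckets: List[List[int]] = [[] for _ in range(partitions)]
        for idx in sorted(range(len(sum_tokens)),
                          key=lambda i: sum_tokens[i], reverse=True):
            total, p = heapq.heappop(heap)
            buckets[p].append(idx)
            heapq.heappush(heap, (total + sum_tokens[idx], p))
        return buckets

    # Skewed sequence lengths split across 2 DP ranks:
    print(balance_by_tokens([900, 100, 500, 500], 2))
    # [[0, 1], [2, 3]]  -> 1000 vs. 1000 tokens per rank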
6.3 SampleInfo Metadata
-----------------------
.. code-block:: python
:caption: siirl/data_coordinator/sample.py
@dataclass
class SampleInfo:
"""Sample metadata for routing and sampling"""
sum_tokens: int = 0 # Total tokens (prompt + response)
prompt_length: int = 0 # Prompt length
response_length: int = 0 # Response length
uid: str = "" # Unique identifier
node_id: Optional[str] = None # Source node ID
dict_info: Dict[str, Any] = field(default_factory=dict) # Extended info
# Common fields:
# - 'key': Target node ID
# - 'source_dp_size': Source DP size
6.4 DAGWorker Data Flow Operations
----------------------------------
.. code-block:: python
:caption: Data flow methods in DAGWorker
def put_data_to_buffers(
self,
key: str,
data: TensorDict,
source_dp_size: int,
dest_dp_size: int,
enforce_buffer: bool = False
):
"""Put data into DataCoordinator"""
        # If source and dest DP sizes match and buffering is not forced, use the local cache
if source_dp_size == dest_dp_size and not enforce_buffer:
self.internal_data_cache[key] = data
else:
# Convert to Sample list
samples = Dict2Samples(data)
# Create metadata
sample_infos = []
for sample in samples:
sample_infos.append(SampleInfo(
sum_tokens=int(sample.attention_mask.sum()),
uid=str(sample.uid),
dict_info={'key': key, 'source_dp_size': source_dp_size}
))
# Upload to Ray Object Store
sample_refs = [ray.put(sample) for sample in samples]
# Register with DataCoordinator
caller_node_id = ray.get_runtime_context().get_node_id()
self.data_coordinator.put_batch.remote(sample_infos, sample_refs, caller_node_id)
def get_data_from_buffers(
self,
key: str,
cur_dp_size: int,
cur_dp_rank: int
) -> Optional[TensorDict]:
"""Get data from DataCoordinator"""
# Check local cache first
if key in self.internal_data_cache:
return self.internal_data_cache.pop(key)
# Define filter function
def key_filter(sample_info: SampleInfo) -> bool:
return sample_info.dict_info.get('key') == key
# Calculate adjusted batch size
rollout_n = self.config.actor_rollout_ref.rollout.n
adjusted_batch_size = int(self.config.data.train_batch_size * rollout_n / cur_dp_size)
# Get from DataCoordinator
sample_refs = ray.get(self.data_coordinator.get_batch.remote(
adjusted_batch_size,
cur_dp_rank,
filter_plugin=key_filter,
balance_partitions=cur_dp_size
))
if not sample_refs:
return None
# Get actual data and collate
samples = ray.get(sample_refs)
return Samples2Dict(samples)
----
.. _sec7_engine:
7. Engine Model Execution
=========================
The Engine module contains all model Worker implementations, supporting both FSDP and Megatron training backends.
7.1 Engine Module Structure
---------------------------
::
engine/
├── actor/ # Actor models
│ ├── base.py # Base class
│ ├── dp_actor.py # FSDP Actor
│ ├── megatron_actor.py # Megatron Actor
│ └── embodied_actor.py # Embodied Actor
├── critic/ # Critic models
│ ├── base.py
│ ├── dp_critic.py
│ └── megatron_critic.py
├── rollout/ # Rollout engine
│ ├── base.py
│ ├── vllm_rollout/ # vLLM backend
│ ├── sglang_rollout/ # SGLang backend
│ ├── hf_rollout.py # HuggingFace backend
│ └── embodied_rollout.py # Embodied Rollout
├── reward_model/ # Reward models
├── reward_manager/ # Reward managers
│ ├── naive.py # Simple Reward
│ ├── parallel.py # Parallel Reward Model
│ ├── dapo.py # DAPO Reward
│ └── embodied.py # Embodied Reward
├── sharding_manager/ # Weight sharding management
│ ├── base.py
│ ├── fsdp_hf.py
│ ├── fsdp_sglang.py
│ ├── fsdp_vllm.py
│ ├── megatron_sglang.py
│ └── megatron_vllm.py
├── fsdp_workers.py # FSDP Worker factory
└── megatron_workers.py # Megatron Worker factory
7.2 Worker Base Class
---------------------
All model Workers inherit from a unified base class:
.. code-block:: python
:caption: siirl/engine/base_worker/base/base_worker.py
class Worker:
"""Abstract base class for all Workers"""
@property
def world_size(self) -> int:
"""Get global world size"""
if not dist.is_initialized():
return 1
return dist.get_world_size()
def init_model(self):
"""Initialize model weights (implemented by subclasses)"""
raise NotImplementedError
7.3 Actor Worker
----------------
Actor Worker is responsible for policy model training:
.. code-block:: python
:caption: siirl/engine/actor/dp_actor.py (simplified)
class FSDPActor(Actor):
"""FSDP Distributed Actor"""
def __init__(self, config, process_group: ProcessGroup):
self.config = config
self.process_group = process_group
# Model related
self.model = None
self.optimizer = None
self.scheduler = None
def init_model(self):
"""Initialize model, optimizer, scheduler"""
# 1. Load model
self.model = self._load_model()
# 2. Apply FSDP wrapping
self.model = FSDP(
self.model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
process_group=self.process_group,
mixed_precision=...,
)
# 3. Create optimizer
self.optimizer = create_optimizer(self.model, self.config.actor.optim)
# 4. Create learning rate scheduler
self.scheduler = create_scheduler(self.optimizer, self.config.actor.optim)
def compute_log_prob(self, batch: TensorDict) -> TensorDict:
"""Compute log probabilities (forward pass, no weight update)"""
with torch.no_grad():
outputs = self.model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
log_probs = compute_log_prob_from_logits(
outputs.logits, batch["responses"], batch["response_mask"]
)
batch["old_log_probs"] = log_probs
return batch
def update_actor(self, batch: TensorDict) -> TensorDict:
"""Execute Actor training step"""
metrics = {}
total_loss = 0.0
for _ in range(self.config.actor.ppo_epochs):
# Forward pass
outputs = self.model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
# Compute current log probabilities
log_probs = compute_log_prob_from_logits(
outputs.logits, batch["responses"], batch["response_mask"]
)
# Compute policy loss
pg_loss, pg_clipfrac, ppo_kl, _ = compute_policy_loss(
old_log_prob=batch["old_log_probs"],
log_prob=log_probs,
advantages=batch["advantages"],
response_mask=batch["response_mask"],
cliprange=self.config.actor.clip_ratio,
)
# Compute entropy loss
entropy_loss = compute_entropy_loss(outputs.logits, batch["response_mask"])
# Total loss
loss = pg_loss - self.config.actor.entropy_coef * entropy_loss
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
if self.config.actor.max_grad_norm:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.config.actor.max_grad_norm
)
# Optimizer step
self.optimizer.step()
self.scheduler.step()
total_loss += loss.item()
metrics["actor/loss"] = total_loss / self.config.actor.ppo_epochs
metrics["actor/pg_clipfrac"] = pg_clipfrac.item()
metrics["actor/ppo_kl"] = ppo_kl.item()
batch["metrics"] = metrics
return batch
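The helper ``compute_log_prob_from_logits`` is not shown above; conceptually it gathers, for every response position, the log-probability the model assigned to the token that was actually sampled. A simplified, self-contained sketch (assuming ``logits`` are already aligned with ``response_ids``):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def compute_log_prob_from_logits(
        logits: torch.Tensor,         # (bs, resp_len, vocab)
        response_ids: torch.Tensor,   # (bs, resp_len)
        response_mask: torch.Tensor,  # (bs, resp_len), 1 for real tokens
    ) -> torch.Tensor:
        """Per-token log-prob of the sampled tokens, zeroed at padding."""
        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = torch.gather(
            log_probs, dim=-1, index=response_ids.unsqueeze(-1)
        ).squeeze(-1)
        return token_log_probs * response_mask

    # Smoke test: batch of 1, two response tokens, vocab of 4.
    logits = torch.randn(1, 2, 4)
    ids = torch.tensor([[2, 1]])
    mask = torch.ones(1, 2)
    print(compute_log_prob_from_logits(logits, ids, mask).shape)  # torch.Size([1, 2])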
7.4 Rollout Worker
------------------
Rollout Worker is responsible for sequence generation:
.. code-block:: python
:caption: siirl/engine/rollout/vllm_rollout/vllm_rollout.py (simplified)
class VLLMRollout:
"""vLLM Inference Backend"""
def __init__(self, config, process_group: ProcessGroup):
self.config = config
self.process_group = process_group
# vLLM LLM instance
self.llm = None
self.tokenizer = None
def init_model(self):
"""Initialize vLLM engine"""
        from vllm import LLM
self.llm = LLM(
model=self.config.model.path,
tensor_parallel_size=self.config.rollout.tensor_model_parallel_size,
trust_remote_code=True,
dtype=self.config.model.dtype,
)
self.tokenizer = self.llm.get_tokenizer()
def generate_sequences(self, batch: TensorDict) -> TensorDict:
"""Generate sequences"""
from vllm import SamplingParams
# Build sampling parameters
sampling_params = SamplingParams(
n=self.config.rollout.n, # GRPO group size
temperature=self.config.rollout.temperature,
top_p=self.config.rollout.top_p,
max_tokens=self.config.data.max_response_length,
)
# Prepare prompts
prompts = batch["prompts"] # List[str] or List[List[int]]
# Generate
outputs = self.llm.generate(prompts, sampling_params)
# Process outputs
all_responses = []
all_response_ids = []
for output in outputs:
for completion in output.outputs:
all_responses.append(completion.text)
all_response_ids.append(completion.token_ids)
        # Update batch (padding of variable-length token_ids omitted for brevity)
batch["responses"] = all_responses
batch["response_ids"] = torch.tensor(all_response_ids)
batch["metrics"] = {
"rollout/avg_response_length": np.mean([len(r) for r in all_response_ids])
}
return batch
7.5 Sharding Manager
--------------------
Sharding Manager is responsible for weight synchronization between Actor and Rollout:
.. code-block:: python
:caption: siirl/engine/sharding_manager/fsdp_vllm.py (simplified)
class FSDPVLLMShardingManager:
"""Weight synchronization between FSDP Actor and vLLM Rollout"""
def __init__(self, actor: FSDPActor, rollout: VLLMRollout, process_group: ProcessGroup):
self.actor = actor
self.rollout = rollout
self.process_group = process_group
def sync_weights_actor_to_rollout(self):
"""Sync Actor weights to Rollout"""
# 1. Gather full weights from FSDP
with FSDP.state_dict_type(
self.actor.model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
):
state_dict = self.actor.model.state_dict()
# 2. Broadcast to all ranks
dist.broadcast_object_list([state_dict], src=0, group=self.process_group)
# 3. Update vLLM model weights
self.rollout.load_weights(state_dict)
----
.. _sec8_core_algorithms:
8. Core Algorithm Implementation
================================
8.1 Advantage Estimators
------------------------
siiRL supports multiple advantage estimation methods:
.. code-block:: python
:caption: siirl/dag_worker/core_algos.py
# Registry decorator
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
def register_adv_est(name_or_enum: str | AdvantageEstimator):
"""Register an advantage estimator"""
def decorator(fn):
name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
ADV_ESTIMATOR_REGISTRY[name] = fn
return fn
return decorator
@register_adv_est(AdvantageEstimator.GAE)
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor, # (bs, response_length)
values: torch.Tensor, # (bs, response_length)
response_mask: torch.Tensor, # (bs, response_length)
gamma: float,
lam: float,
):
"""GAE (Generalized Advantage Estimation) for PPO"""
with torch.no_grad():
nextvalues = 0
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam_ = delta + gamma * lam * lastgaelam
# Skip padding tokens
nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = masked_whiten(advantages, response_mask)
return advantages, returns
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, # (bs, response_length)
response_mask: torch.Tensor, # (bs, response_length)
index: np.ndarray, # Index for grouping
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
):
"""GRPO (Group Relative Policy Optimization)"""
scores = token_level_rewards.sum(dim=-1) # Sequence-level rewards
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
# Group by prompt
for i in range(bsz):
idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
id2score[idx_key].append(scores[i])
# Compute group mean and std
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
# Normalize
for i in range(bsz):
idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
if norm_adv_by_std_in_grpo:
scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)
else: # Dr.GRPO
scores[i] = scores[i] - id2mean[idx_key]
scores = scores.unsqueeze(-1) * response_mask
return scores, scores
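A quick numeric check of the group normalization: with ``n = 4`` rollouts of one prompt scoring ``[1, 0, 0, 1]``, the group mean is 0.5 and the (unbiased) std is about 0.577, so correct answers receive a positive advantage and incorrect ones a negative advantage of equal magnitude:

.. code-block:: python

    import torch

    scores = torch.tensor([1.0, 0.0, 0.0, 1.0])  # one GRPO group (same prompt)
    mean, std = scores.mean(), scores.std()       # 0.5 and ~0.5774
    advantages = (scores - mean) / (std + 1e-6)
    print(advantages)  # tensor([ 0.8660, -0.8660, -0.8660,  0.8660])

Note that an all-correct or all-incorrect group has zero std, so every advantage collapses to zero; this is exactly the zero-variance case that DAPO's dynamic sampling filters out.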
8.2 Policy Loss Functions
-------------------------
siiRL supports multiple policy loss functions:
.. code-block:: python
:caption: siirl/dag_worker/core_algos.py
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
def register_policy_loss(name: str):
"""Register a policy loss function"""
def decorator(func: PolicyLossFn) -> PolicyLossFn:
POLICY_LOSS_REGISTRY[name] = func
return func
return decorator
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Standard PPO policy loss (dual-clip)"""
clip_ratio = config.clip_ratio
clip_ratio_low = config.clip_ratio_low or clip_ratio
clip_ratio_high = config.clip_ratio_high or clip_ratio
clip_ratio_c = config.clip_ratio_c
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, response_mask)
# Standard PPO clipping
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
# Dual clipping (negative advantage scenario)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
# Apply importance weights
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(pg_losses, response_mask, loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("cpgd")
def compute_policy_loss_cpgd(...):
"""CPGD policy loss (direct log_prob clipping)"""
...
@register_policy_loss("gspo")
def compute_policy_loss_gspo(...):
"""GSPO policy loss (sequence-level importance ratio)"""
...
@register_policy_loss("gpg")
def compute_policy_loss_gpg(...):
"""GPG policy loss (REINFORCE style)"""
...
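The losses above lean on two helpers elided from the snippet: ``masked_mean`` and ``agg_loss``. A plausible sketch of both, covering the ``token-mean`` aggregation mode referenced above (the ``seq-mean-token-mean`` branch is an assumed illustration of an alternative mode, not a confirmed siiRL option):

.. code-block:: python

    import torch

    def masked_mean(values: torch.Tensor, mask: torch.Tensor,
                    eps: float = 1e-8) -> torch.Tensor:
        """Mean of `values` over positions where `mask` is 1."""
        return (values * mask).sum() / (mask.sum() + eps)

    def agg_loss(loss_mat: torch.Tensor, mask: torch.Tensor,
                 mode: str = "token-mean") -> torch.Tensor:
        """Aggregate a (bs, resp_len) per-token loss matrix into a scalar."""
        if mode == "token-mean":
            # Average over all non-padding tokens in the batch.
            return masked_mean(loss_mat, mask)
        if mode == "seq-mean-token-mean":
            # First average within each sequence, then across sequences.
            per_seq = (loss_mat * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
            return per_seq.mean()
        raise ValueError(f"unknown loss_agg_mode: {mode}")

    loss = torch.tensor([[0.5, 1.0, 0.0], [2.0, 0.0, 0.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
    print(agg_loss(loss, mask))  # tensor(1.1667) = (0.5 + 1.0 + 2.0) / 3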
8.3 KL Penalty
--------------
.. code-block:: python
class AdaptiveKLController:
"""Adaptive KL Controller"""
def __init__(self, init_kl_coef, target_kl, horizon):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
def apply_kl_penalty(data: TensorDict, kl_ctrl, kl_penalty="kl"):
"""Apply KL penalty to token-level rewards"""
kld = kl_penalty_fn(data["old_log_probs"], data["ref_log_prob"], kl_penalty)
kld = kld * data["response_mask"]
beta = kl_ctrl.value
data["token_level_rewards"] = data["token_level_scores"] - beta * kld
current_kl = masked_mean(kld, data["response_mask"]).item()
kl_ctrl.update(current_kl=current_kl, n_steps=data.batch_size[0])
return data, {"actor/reward_kl_penalty": current_kl, "actor/kl_coef": beta}
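The ``kl_penalty_fn`` call above is also elided. For the default ``"kl"`` penalty it reduces to the per-token log-ratio between the trained and reference policies; the other variants sketched below (``abs`` and the low-variance k3 estimator) are common alternatives, included here as assumptions rather than confirmed siiRL modes:

.. code-block:: python

    import torch

    def kl_penalty_fn(log_probs: torch.Tensor, ref_log_probs: torch.Tensor,
                      kl_penalty: str = "kl") -> torch.Tensor:
        """Per-token KL penalty estimators between policy and reference."""
        log_ratio = log_probs - ref_log_probs
        if kl_penalty == "kl":
            return log_ratio                # plain log-ratio estimator
        if kl_penalty == "abs":
            return log_ratio.abs()          # magnitude-only variant
        if kl_penalty == "low_var_kl":
            # Schulman's k3 estimator: non-negative, lower variance.
            return (torch.exp(-log_ratio) - 1) + log_ratio
        raise ValueError(f"unknown kl_penalty: {kl_penalty}")

    old = torch.tensor([[-1.0, -2.0]])
    ref = torch.tensor([[-1.5, -1.5]])
    print(kl_penalty_fn(old, ref, "kl"))  # tensor([[ 0.5000, -0.5000]])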
----
.. _sec9_execution_flow:
9. Complete Execution Flow
==========================
9.1 GRPO Training Flow
----------------------
Using GRPO as an example, showing the complete training flow:
::
┌──────────────────────────────────────────────────────────────────────────────┐
│ GRPO Single Step Training Flow │
└──────────────────────────────────────────────────────────────────────────────┘
[1. Data Loading]
│
│ DataLoader.run() → batch (prompts, attention_mask, ...)
│
▼
[2. Rollout Generation] ───────────────────────────────────────────────────────
│
│ DAGWorker.generate()
│ │
│ ├── Prepare generation batch
│ ├── rollout_worker.generate_sequences(batch)
│ │ │
│ │ ├── vLLM/SGLang/HF inference
│ │ └── Return responses, response_ids
│ │
│ └── Update batch: responses, response_mask
│
│ Output: batch with responses (bs * n_samples, seq_len)
│
▼
[3. Reward Computation] ──────────────────────────────────────────────────────
│
│ DAGWorker.compute_reward()
│ │
│ ├── reward_fn.score(batch) → token_level_scores
│ │
│ ├── (Optional) Apply KL penalty:
│ │ kl = old_log_prob - ref_log_prob
│ │ token_level_rewards = token_level_scores - β * kl
│ │
│ └── Otherwise: token_level_rewards = token_level_scores
│
│ Output: batch with token_level_rewards
│
▼
[4. Advantage Computation] ───────────────────────────────────────────────────
│
│ DAGWorker.compute_advantage()
│ │
│ └── compute_grpo_outcome_advantage()
│ │
│ ├── Compute sequence-level scores: scores = rewards.sum(dim=-1)
│ ├── Group by prompt
│ ├── Compute group mean and std
│ └── Normalize: (scores - mean) / std
│
│ Output: batch with advantages
│
▼
[5. Actor Forward] ───────────────────────────────────────────────────────────
│
│ DAGWorker.compute_old_log_prob()
│ │
│ └── actor_worker.compute_log_prob(batch)
│ │
│ ├── Forward pass (no_grad)
│ └── Compute old_log_probs
│
│ Output: batch with old_log_probs
│
▼
[6. Reference Forward] ───────────────────────────────────────────────────────
│
│ DAGWorker.compute_ref_log_prob()
│ │
│ └── reference_worker.compute_ref_log_prob(batch)
│ │
│ ├── Forward pass (no_grad)
│ └── Compute ref_log_prob
│
│ Output: batch with ref_log_prob
│
▼
[7. Actor Training] ──────────────────────────────────────────────────────────
│
│ DAGWorker.train_actor()
│ │
│ └── actor_worker.update_actor(batch)
│ │
│ ├── for _ in range(ppo_epochs):
│ │ │
│ │ ├── Forward pass → log_probs
│ │ ├── Compute policy loss:
│ │ │ pg_loss = -advantages * clipped_ratio
│ │ ├── Compute entropy loss
│ │ ├── Total loss = pg_loss - entropy_coef * entropy
│ │ ├── Backward pass
│ │ └── Optimizer step
│ │
│ └── Return metrics
│
│ Output: batch with metrics
│
▼
[8. Sync Weights]
│
│ sharding_manager.sync_weights_actor_to_rollout()
│
▼
[Done: Continue to next step]
9.2 PPO Training Flow
---------------------
PPO adds Critic model and GAE computation compared to GRPO:
::
GRPO flow + the following additional steps:
[3.5. Value Computation] (After Reward, before Advantage)
│
│ DAGWorker.compute_value()
│ │
│ └── critic_worker.compute_values(batch)
│ │
│ ├── Forward pass (no_grad)
│ └── Compute values
│
│ Output: batch with values
[4. Advantage Computation] (Uses GAE instead of GRPO)
│
│ compute_gae_advantage_return()
│ │
│ ├── Reverse iterate through response tokens
│ ├── Compute TD-error: δ = r + γV(s') - V(s)
│ └── GAE: A = δ + γλA'
[7.5. Critic Training] (After Actor training)
│
│ DAGWorker.train_critic()
│ │
│ └── critic_worker.update_critic(batch)
│ │
│ ├── Forward pass → vpreds
│ ├── Compute Value loss:
│ │ vf_loss = clipped_mse(vpreds, returns)
│ ├── Backward pass
│ └── Optimizer step
----
.. _sec10_configuration:
10. Configuration Parameters
============================
10.1 Configuration File Structure
---------------------------------
siiRL uses Hydra for configuration management, with main configuration groups:
.. code-block:: yaml
:caption: Configuration File Structure
# algorithm: Algorithm configuration
algorithm:
adv_estimator: grpo # grpo/gae/cpgd/gspo
workflow_type: DEFAULT # DEFAULT/DAPO/EMBODIED
gamma: 1.0 # Discount factor
lam: 0.95 # GAE lambda
use_kl_in_reward: false # Whether to use KL penalty in reward
norm_adv_by_std_in_grpo: true
kl_ctrl:
type: fixed # fixed/adaptive
kl_coef: 0.001
# data: Data configuration
data:
train_files: /path/to/train.parquet
train_batch_size: 512
max_prompt_length: 2048
max_response_length: 4096
num_loader_workers: 4
# actor_rollout_ref: Model configuration
actor_rollout_ref:
model:
path: /path/to/model
dtype: bfloat16
trust_remote_code: true
actor:
strategy: fsdp # fsdp/megatron
clip_ratio: 0.2
entropy_coef: 0.01
ppo_epochs: 1
ppo_mini_batch_size: 256
max_grad_norm: 1.0
optim:
lr: 1e-6
weight_decay: 0.01
scheduler: cosine_with_warmup
warmup_ratio: 0.1
rollout:
name: vllm # vllm/sglang/hf
tensor_model_parallel_size: 2
n: 8 # GRPO group size
temperature: 1.0
top_p: 1.0
mode: sync # sync/async
# trainer: Trainer configuration
trainer:
n_gpus_per_node: 8
nnodes: 1
total_epochs: 30
save_freq: 10
test_freq: 5
val_before_train: false
critic_warmup: 0
project_name: my_project
experiment_name: grpo_training
logger: wandb # wandb/tensorboard/console
# dag: DAG configuration
dag:
custom_pipeline_fn: null # Custom Pipeline function path
enable_perf: false
backend_threshold: 32
10.2 Key Parameter Descriptions
-------------------------------
.. list-table:: Key Configuration Parameters
:header-rows: 1
:widths: 30 15 55
* - Parameter
- Default
- Description
* - ``algorithm.adv_estimator``
- grpo
- Advantage estimator (grpo/gae/cpgd/gspo)
* - ``algorithm.workflow_type``
- DEFAULT
- Workflow type (DEFAULT/DAPO/EMBODIED)
* - ``data.train_batch_size``
- 512
- Global training batch size
* - ``actor_rollout_ref.rollout.n``
- 8
- GRPO samples per prompt
* - ``actor_rollout_ref.actor.clip_ratio``
- 0.2
- PPO clipping ratio
* - ``actor_rollout_ref.actor.ppo_epochs``
- 1
- PPO epochs per training step
* - ``actor_rollout_ref.rollout.tensor_model_parallel_size``
- 1
- Rollout TP size
* - ``trainer.save_freq``
- 10
- Checkpoint save frequency (steps)
* - ``trainer.test_freq``
- 5
- Validation frequency (steps)
10.3 How to Add New Configuration Items
---------------------------------------
siiRL uses Python dataclasses for configuration management. Here's how to add new configuration items:
**Step 1: Identify the Configuration Group**
Configuration is organized into the following groups in ``siirl/params/``:
::
siirl/params/
├── __init__.py # Exports all argument classes
├── training_args.py # TrainingArguments, SiiRLArguments (root)
├── model_args.py # ActorArguments, RolloutArguments, AlgorithmArguments, etc.
├── data_args.py # DataArguments
├── dag_args.py # DagArguments
├── profiler_args.py # ProfilerArguments
└── embodied_args.py # EmbodiedArguments
**Step 2: Add a New Field to the Appropriate Dataclass**
Example: Adding a new ``max_retry_count`` field to ``TrainingArguments``:
.. code-block:: python
:caption: siirl/params/training_args.py
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TrainingArguments:
# Existing fields...
total_epochs: int = field(default=30, metadata={"help": "Total training epochs"})
save_freq: int = field(default=-1, metadata={"help": "Checkpoint frequency"})
# Add your new field here
max_retry_count: int = field(
default=3,
metadata={"help": "Maximum retry count for failed training steps"}
)
**Step 3: Add a New Argument Group (if needed)**
If adding a completely new category, create a new dataclass and register it in ``SiiRLArguments``:
.. code-block:: python
:caption: siirl/params/my_custom_args.py (new file)
from dataclasses import dataclass, field
from typing import Dict, Any
@dataclass
class MyCustomArguments:
"""Custom arguments for new feature."""
enable_feature: bool = field(
default=False,
metadata={"help": "Enable the custom feature"}
)
feature_threshold: float = field(
default=0.5,
metadata={"help": "Threshold for the custom feature"}
)
feature_config: Dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Additional configuration for the feature"}
)
def to_dict(self) -> Dict[str, Any]:
from dataclasses import asdict
return asdict(self)
Then register in ``SiiRLArguments``:
.. code-block:: python
:caption: siirl/params/training_args.py
from siirl.params.my_custom_args import MyCustomArguments
@dataclass
class SiiRLArguments:
data: DataArguments = field(default_factory=DataArguments)
actor_rollout_ref: ActorRolloutRefArguments = field(default_factory=ActorRolloutRefArguments)
# ... existing fields ...
# Add your new argument group
my_custom: MyCustomArguments = field(default_factory=MyCustomArguments)
**Step 4: Export in __init__.py**
.. code-block:: python
:caption: siirl/params/__init__.py
from .my_custom_args import MyCustomArguments
__all__ = [
# ... existing exports ...
"MyCustomArguments",
]
**Step 5: Use in YAML Configuration**
After adding the new fields, you can use them in your YAML configuration:
.. code-block:: yaml
:caption: config.yaml
trainer:
total_epochs: 30
save_freq: 10
max_retry_count: 5 # Your new field
my_custom: # Your new argument group
enable_feature: true
feature_threshold: 0.7
feature_config:
key1: value1
key2: value2
**Step 6: Access in Code**
.. code-block:: python
def my_function(config: SiiRLArguments):
# Access top-level trainer config
max_retry = config.trainer.max_retry_count
# Access your custom argument group
if config.my_custom.enable_feature:
threshold = config.my_custom.feature_threshold
extra_config = config.my_custom.feature_config
**Configuration Hierarchy**:
::
SiiRLArguments (root)
├── data: DataArguments
│ ├── train_files
│ ├── train_batch_size
│ └── ...
├── actor_rollout_ref: ActorRolloutRefArguments
│ ├── model: ModelArguments
│ ├── actor: ActorArguments
│ │ ├── strategy
│ │ ├── clip_ratio
│ │ ├── optim: OptimizerArguments
│ │ └── ...
│ ├── rollout: RolloutArguments
│ └── ref: RefArguments
├── critic: CriticArguments
├── reward_model: RewardModelArguments
├── algorithm: AlgorithmArguments
│ ├── adv_estimator
│ ├── workflow_type
│ └── kl_ctrl: KLCtrlArguments
├── trainer: TrainingArguments
├── custom_reward_function: CustomRewardArguments
├── dag: DagArguments
└── profiler: ProfilerArguments
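
Any node in this tree can also be overridden from the command line using dotted paths, in the same style as the quickstart script. A short sketch (``trainer.max_retry_count`` and ``my_custom`` assume the fields added in Steps 2 and 3 above):

.. code-block:: bash

    python -m siirl.main_dag \
        algorithm.adv_estimator=grpo \
        trainer.total_epochs=30 \
        trainer.max_retry_count=5 \
        my_custom.enable_feature=true
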
----
.. _sec11_extension_guide:
11. Extension Guide
===================
11.1 Custom Pipeline
--------------------
Users can define custom Pipelines:
.. code-block:: python
:caption: examples/custom_pipeline_example/custom_pipeline.py
from siirl.execution.dag.pipeline import Pipeline
from siirl.execution.dag.node import NodeType, NodeRole
from siirl.execution.dag.task_graph import TaskGraph
def my_custom_pipeline() -> TaskGraph:
"""Custom training pipeline"""
pipeline = Pipeline("my_custom_pipeline", "My custom RL workflow")
# Add custom nodes
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"custom_reward",
func="my_module.custom_reward:compute_custom_reward", # Custom function
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["custom_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
return pipeline.build()
Specify in configuration:
.. code-block:: yaml
dag:
custom_pipeline_fn: "my_module.custom_pipeline:my_custom_pipeline"
11.2 Custom Reward Function
---------------------------
.. code-block:: python
:caption: siirl/user_interface/rewards_interface/custom_reward.py
    import numpy as np
    import torch

    from siirl.dag_worker.data_structures import NodeOutput
    from tensordict import TensorDict
def compute_custom_reward(batch: TensorDict, config, **kwargs) -> NodeOutput:
"""Custom Reward computation function"""
# Get generated responses
responses = batch["responses"]
prompts = batch["prompts"]
# Custom reward logic
rewards = []
for prompt, response in zip(prompts, responses):
# Implement your reward function
score = my_scoring_function(prompt, response)
rewards.append(score)
# Convert to token-level rewards
token_level_rewards = torch.zeros_like(batch["attention_mask"])
for i, score in enumerate(rewards):
# Assign sequence-level reward to last token
token_level_rewards[i, -1] = score
batch["token_level_scores"] = token_level_rewards
batch["token_level_rewards"] = token_level_rewards
metrics = {"reward/mean_score": np.mean(rewards)}
return NodeOutput(batch=batch, metrics=metrics)
11.3 Custom Advantage Estimator
-------------------------------
.. code-block:: python
:caption: Registering Custom Advantage Estimator
    import torch

    from siirl.dag_worker.core_algos import register_adv_est
    from siirl.execution.scheduler.enums import AdvantageEstimator
@register_adv_est("my_custom_adv") # Or use enum
def compute_my_custom_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
**kwargs
):
"""Custom Advantage estimation"""
# Implement your advantage estimation logic
advantages = ...
returns = ...
return advantages, returns
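
For a concrete (if simplistic) instance, the sketch below registers a group-mean-baseline estimator. The registry name ``group_mean_baseline`` is illustrative, not a built-in, and the ``index`` argument follows the convention of the built-in GRPO estimator (one prompt id per sample):

.. code-block:: python

    import numpy as np
    import torch

    from siirl.dag_worker.core_algos import register_adv_est

    @register_adv_est("group_mean_baseline")  # illustrative name
    def compute_group_mean_advantage(
        token_level_rewards: torch.Tensor,  # (B, response_length)
        response_mask: torch.Tensor,        # (B, response_length)
        index: np.ndarray,                  # (B,) prompt id per sample
        **kwargs,
    ):
        """Advantage = outcome reward minus the mean reward of the sample's group."""
        scores = token_level_rewards.sum(dim=-1)  # scalar reward per trajectory
        advantages = scores.clone()
        with torch.no_grad():
            for uid in set(index.tolist()):
                group = torch.tensor([i for i, u in enumerate(index) if u == uid])
                advantages[group] = scores[group] - scores[group].mean()
        # Broadcast the scalar advantage over response tokens, masking padding
        advantages = advantages.unsqueeze(-1) * response_mask
        return advantages, advantages  # (advantages, returns)
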
11.4 Custom Policy Loss
-----------------------
.. code-block:: python
:caption: Registering Custom Policy Loss
    import torch

    from siirl.dag_worker.core_algos import register_policy_loss
@register_policy_loss("my_custom_loss")
def compute_my_custom_policy_loss(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config = None,
rollout_is_weights = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Custom policy loss"""
# Implement your policy loss logic
pg_loss = ...
pg_clipfrac = ...
ppo_kl = ...
pg_clipfrac_lower = ...
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
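
As a concrete instance, here is a minimal sketch registering the plain (unclipped) policy-gradient objective. The registry name ``reinforce_example`` and the local ``_masked_mean`` helper are illustrative only, not part of siiRL:

.. code-block:: python

    import torch

    from siirl.dag_worker.core_algos import register_policy_loss

    def _masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return (x * mask).sum() / mask.sum().clamp(min=1)

    @register_policy_loss("reinforce_example")  # illustrative name
    def compute_reinforce_loss(
        old_log_prob: torch.Tensor,
        log_prob: torch.Tensor,
        advantages: torch.Tensor,
        response_mask: torch.Tensor,
        loss_agg_mode: str = "token-mean",
        config=None,
        rollout_is_weights=None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """L = -E[A_t * log pi(a_t|s_t)], i.e. REINFORCE without clipping."""
        pg_loss = _masked_mean(-advantages * log_prob, response_mask)
        ppo_kl = _masked_mean(old_log_prob - log_prob, response_mask)
        zero = torch.zeros((), device=log_prob.device)  # no clipping -> clip fractions are 0
        return pg_loss, zero, ppo_kl, zero
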
----
Appendix A: Code File Navigation
================================
::
siirl/
├── main_dag.py # Main entry point
├── dag_worker/ # DAG Worker module
│ ├── dagworker.py # Core Worker class (~1320 lines)
│ ├── core_algos.py # RL algorithm implementations
│ ├── dag_utils.py # Utility functions
│ ├── checkpoint_manager.py # Checkpoint management
│ ├── validator.py # Validation logic
│ ├── metrics_collector.py # Metrics collection
│ └── data_structures.py # Data structure definitions
├── execution/ # Execution engine
│ ├── dag/ # DAG definitions
│ │ ├── __init__.py # Module exports
│ │ ├── task_graph.py # TaskGraph class
│ │ ├── node.py # Node/NodeType/NodeRole/NodeStatus classes
│ │ ├── pipeline.py # Pipeline Builder API
│ │ ├── builtin_pipelines.py # Built-in Pipelines (GRPO/PPO/DAPO/Embodied)
│ │ └── task_loader.py # Graph splitting utilities
│ ├── scheduler/ # Scheduler
│ │ ├── task_scheduler.py # Task scheduling
│ │ ├── process_group_manager.py # Process group management
│ │ ├── launch.py # Ray launcher
│ │ └── enums.py # Enum definitions
│ └── metric_worker/ # Distributed metrics
│ └── metric_worker.py # MetricWorker Actor
├── engine/ # Model execution engine
│ ├── actor/ # Actor Workers
│ ├── critic/ # Critic Workers
│ ├── rollout/ # Rollout Workers (vLLM/SGLang/HF)
│ ├── reward_model/ # Reward Model Workers
│ ├── reward_manager/ # Reward Managers (naive/parallel/dapo/embodied)
│ └── sharding_manager/ # Weight sharding management (FSDP/Megatron)
├── data_coordinator/ # Data coordinator
│ ├── data_buffer.py # DataCoordinator Actor
│ ├── dataloader/ # Distributed DataLoader
│ ├── protocol.py # Data protocol
│ └── sample.py # Sample/SampleInfo
├── user_interface/ # User extension interfaces
│ ├── filter_interface/ # Filtering plugins
│ │ ├── dapo.py # DAPO dynamic sampling
│ │ └── embodied.py # Embodied dynamic sampling
│ └── rewards_interface/ # Custom reward interfaces
├── params/ # Configuration parameters
│ ├── __init__.py # SiiRLArguments
│ ├── parser.py # Configuration parser
│ ├── data_args.py # Data parameters
│ ├── model_args.py # Model parameters
│ └── training_args.py # Training parameters
└── utils/ # Utilities
├── checkpoint/ # Checkpoint utilities
├── logger/ # Logging utilities
├── model_utils/ # Model utilities
└── reward_score/ # Reward computation
----
Summary
=======
This document provides a comprehensive guide to siiRL's architecture implementation, including:
1. **Architecture Overview**: siiRL's position in distributed RL systems and core advantages
2. **DistFlow Design Philosophy**: Fully distributed, multi-controller paradigm design
3. **Program Entry**: main_dag.py and MainRunner startup flow
4. **DAG Planner**: Pipeline API, TaskGraph, TaskScheduler implementation
5. **DAG Worker**: Core execution unit initialization, training loop, node execution
6. **Data Coordinator**: Distributed data management and length balancing algorithm
7. **Engine**: Actor/Critic/Rollout/Reference/Reward Worker implementations
8. **Core Algorithms**: Advantage estimators, Policy Loss function implementations
9. **Execution Flow**: Complete GRPO/PPO training flows
10. **Configuration**: Key configuration parameters explained
11. **Extension Guide**: Custom Pipeline, Reward, Advantage, Policy Loss
By reading this document, readers should gain a deep understanding of siiRL's design philosophy and implementation details, providing a solid foundation for future development, optimization, and extension work.
**References**:
- siiRL Paper: `DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training `__
- Official Documentation: `https://siirl.readthedocs.io/ `__
- GitHub Repository: `https://github.com/sii-research/siiRL `__
================================================
FILE: docs/programming_guide/srpo_code_explained.rst
================================================
SRPO Code Implementation Explained
==================================
This document provides a comprehensive guide to understanding the SRPO (Self-Referential Policy Optimization) algorithm implementation in siiRL. SRPO is designed for training Vision-Language-Action (VLA) models in embodied AI scenarios.
.. note::
**Paper Reference**: `SRPO: Self-Referential Policy Optimization for Vision-Language-Action Models `_
Overview: What is SRPO?
-----------------------
**Self-Referential Policy Optimization (SRPO) for Vision-Language-Action Models** is a novel VLA-RL framework. SRPO eliminates the need for external demonstrations or manual reward engineering by leveraging successful trajectories generated by the model within the current training batch as self-references. This enables us to assign progress-based rewards to failed attempts.
A core innovation is the use of **latent world representations** (V-JEPA) to robustly measure behavioral progress. Rather than relying on raw pixels or requiring domain-specific fine-tuning, we utilize compressed, transferable encodings from a world model's latent space. These representations naturally capture progress patterns across environments, making trajectory comparison accurate and generalizable.
Empirical evaluation on the LIBERO benchmark demonstrates SRPO's efficiency and effectiveness. Starting from a supervised baseline with a 48.9% success rate, SRPO achieves a 99.2% success rate on novel states within only 200 RL steps, representing a 103% relative improvement without any additional supervision. Furthermore, SRPO shows significant robustness on the LIBERO-Plus benchmark, achieving a 167% performance gain.
**In siiRL, SRPO is implemented as the** ``embodied_srpo_pipeline`` **+** ``GRPO`` **advantage estimator.**
Code Architecture Overview
--------------------------
.. code-block:: text
siiRL/
├── siirl/
│ ├── execution/
│ │ └── dag/
│ │ └── builtin_pipelines.py # embodied_srpo_pipeline() definition
│ ├── user_interface/
│ │ └── filter_interface/
│ │ └── embodied.py # embodied_local_rank_sampling()
│ ├── engine/
│ │ ├── rollout/
│ │ │ └── embodied_rollout.py # EmbodiedHFRollout class
│ │ └── actor/
│ │ └── embodied_actor.py # RobDataParallelPPOActor class
│ ├── dag_worker/
│ │ └── core_algos.py # GRPO advantage & PPO loss
│ ├── environment/
│ │ └── embodied/
│ │ └── adapters/ # LIBERO environment adapter
│ └── utils/
│ ├── reward_score/
│ │ └── embodied.py # compute_embodied_reward()
│ └── embodied/
│ └── video_emb.py # VideoEmbeddingModel (V-JEPA)
└── examples/
└── embodied_srpo_trainer/
└── run_openvla_oft_libero_*.sh # Training scripts
Training Pipeline Definition
----------------------------
The SRPO training pipeline is defined in ``siirl/execution/dag/builtin_pipelines.py`` using the Python Pipeline API:
.. code-block:: python
:caption: siirl/execution/dag/builtin_pipelines.py - embodied_srpo_pipeline()
def embodied_srpo_pipeline() -> TaskGraph:
"""
Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.
Workflow:
1. rollout_actor: Environment rollout with embodied AI agent
2. dynaminc_sampling: Data verification and filtering
3. compute_reward: VJEPA-based reward computation
4. calculate_advantages: Calculate advantages (GRPO group-based)
5. actor_old_log_prob: Compute old actor log probabilities (forward only)
6. reference_log_prob: Compute reference model log probabilities
7. actor_train: Actor training with GRPO
"""
pipeline = Pipeline(
"embodied_grpo_training_pipeline",
"Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation."
)
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"dynaminc_sampling",
func="siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
).add_node(
"compute_reward",
func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
deps=["dynaminc_sampling"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.REWARD
).add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["compute_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.ADVANTAGE
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR,
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.REFERENCE
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
return pipeline.build()
Data Flow Diagram
~~~~~~~~~~~~~~~~~
.. code-block:: text
SRPO Training Pipeline Data Flow
==============================================================================
DataLoader (task_id, trial_id)
|
v
+---------------------+
| rollout_actor | EmbodiedHFRollout.generate_sequences()
| (MODEL_INFERENCE) | -> VLA model + LIBERO environment interaction
+----------+----------+
| Output: {responses, input_ids, attention_mask, pixel_values,
| complete, finish_step, vjepa_embedding, task_file_name}
v
+---------------------+
| dynamic_sampling | embodied_local_rank_sampling()
| (COMPUTE) | -> verify() + _filter_batch()
+----------+----------+ Filter by accuracy bounds & truncation
| Output: filtered batch (samples with 0.1 <= acc <= 0.9)
v
+---------------------+
| compute_reward | compute_embodied_reward()
| (COMPUTE) | -> VJEPA-based reward shaping
+----------+----------+ Success: reward=1.0, Failure: reward=sigmoid(distance)
| Output: + {token_level_scores, token_level_rewards}
v
+---------------------+
| calculate_advantages| compute_grpo_outcome_advantage()
| (COMPUTE) | -> Group by prompt, normalize (score - mean) / std
+----------+----------+
| Output: + {advantages, returns}
v
+---------------------+
| actor_old_log_prob | RobDataParallelPPOActor.compute_log_prob()
| (MODEL_TRAIN) | -> Forward only, no gradient
| only_forward=True |
+----------+----------+
| Output: + {old_log_probs}
v
+---------------------+
| reference_log_prob | Reference model forward pass
| (MODEL_TRAIN) |
+----------+----------+
| Output: + {ref_log_prob}
v
+---------------------+
| actor_train | RobDataParallelPPOActor.update_policy()
| (MODEL_TRAIN) | -> compute_policy_loss_vanilla() (PPO clipped loss)
+---------------------+
|
| Metrics: {pg_loss, pg_clipfrac, ppo_kl, grad_norm}
v
+---------------------+
| sync_weights | ShardingManager (if needed)
+---------------------+
Core Components Deep Dive
-------------------------
1. Rollout: Environment Interaction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**File**: ``siirl/engine/rollout/embodied_rollout.py``
**Class**: ``EmbodiedHFRollout``
This is the core component that orchestrates the interaction between the VLA model and the simulation environment (LIBERO). It handles the complete episode generation process including action prediction, environment stepping, and visual embedding extraction.
Class Initialization
^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
class EmbodiedHFRollout(BaseRollout):
def __init__(self, module: nn.Module, config: ActorRolloutRefArguments):
self.model = module # VLA model (e.g., OpenVLA-OFT)
self.config = config
# Initialize V-JEPA embedding model for reward computation
self.embedding_model = VideoEmbeddingModel(
model_path=config.embodied.video_embedding_model_path,
img_size=config.embodied.embedding_img_size,
enable_fp16=config.embodied.embedding_enable_fp16
)
# Initialize LIBERO environment adapter with parallel environments
self.num_workers = config.embodied.env.num_envs # e.g., 16 parallel envs
self.adapter = LIBEROAdapter(
env_name=config.embodied.env.env_name, # e.g., "libero_goal"
num_envs=self.num_workers,
max_steps=config.embodied.env.max_steps, # e.g., 512
num_steps_wait=config.embodied.env.num_steps_wait,
model_family=config.embodied.env.model_family,
gpu_ids=[self._rank % self._num_gpus_per_node]
)
Main Entry Point: generate_sequences()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def generate_sequences(self, prompts):
"""
Main entry point for generating sequences.
Splits large batches into chunks that fit the number of parallel workers.
"""
total_batch_size = prompts.batch_size[0]
n_samples = prompts['n_samples'] if 'n_samples' in prompts else 1
# Each prompt needs n_samples trajectories
batch_size_per_chunk = self.num_workers // n_samples
num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk
all_chunk_outputs = []
        for i in range(num_chunks):
            # Slice out the prompts handled by this round of parallel workers
            start_idx = i * batch_size_per_chunk
            end_idx = min(start_idx + batch_size_per_chunk, total_batch_size)
            chunk_prompts = prompts[start_idx:end_idx]
chunk_output = self._generate_chunk_rollout(chunk_prompts)
all_chunk_outputs.append(chunk_output)
return torch.cat(all_chunk_outputs)
Episode Generation Loop: _generate_chunk_rollout()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This is the heart of the embodied rollout - a step-by-step interaction loop between the VLA model and the environment.
.. code-block:: python
def _generate_chunk_rollout(self, prompts):
"""Generate complete episodes for a chunk of tasks."""
task_id = prompts['task_id']
trial_id = prompts['trial_id']
max_steps = self.config.embodied.env.max_steps
chunk_size = task_id.size(0)
# Step 1: Reset all parallel environments
init_data_list = self.adapter._blocking_reset(
task_ids=task_id.reshape(-1).cpu().numpy().tolist(),
trial_ids=trial_id.reshape(-1).cpu().numpy().tolist(),
)
# Collect initial observations
inputs = [self._obs_to_input(init_data['obs']) for init_data in init_data_list]
task_descriptions = [init_data["task_description"] for init_data in init_data_list]
task_records = [{"active": d['active'], "complete": d['complete'],
"finish_step": d['finish_step'], "task_file_name": d['task_file_name']}
for d in init_data_list]
# Step 2: Main interaction loop (up to max_steps)
step = 0
vla_history = [] # Store all step data for training
while step < max_steps:
active_indices = [i for i, r in enumerate(task_records) if r['active']]
# Step 2a: Process observations into VLA input format
vla_input = self.process_input(inputs, task_descriptions)
# Step 2b: VLA model predicts actions
vla_output = self._generate_one_step(vla_input)
actions = vla_output["action"]
# Store step data for later training
vla_history.append({
"responses": vla_output["responses"],
"input_ids": vla_output["input_ids"],
"attention_mask": vla_output["attention_mask"],
"pixel_values": vla_output["pixel_values"],
"action": actions,
"step": step
})
# Step 2c: Execute actions in environment
step_results_list = self.adapter._blocking_step({
"indices": active_indices,
"actions": actions,
})
# Step 2d: Update observations and task status
for idx in active_indices:
result = step_results_list[idx]
inputs[idx] = self._obs_to_input(result['obs'])
task_records[idx]['active'] = result['active']
task_records[idx]['complete'] = result['complete']
task_records[idx]['finish_step'] = result['finish_step']
step += self.config.embodied.action_chunks_len # e.g., += 8
# Step 3: Post-processing - Stack history and compute embeddings
batch = {}
for k in ["responses", "input_ids", "attention_mask", "pixel_values"]:
batch[k] = torch.stack([h[k] for h in vla_history], dim=1)
batch["complete"] = torch.tensor([r["complete"] for r in task_records])
batch["finish_step"] = torch.tensor([r["finish_step"] for r in task_records])
        # Step 4: Extract V-JEPA embeddings for reward computation
        # (excerpt) all_video maps task_file_name -> the frames recorded during the rollout
        batch_names, batch_frames = zip(*[(r['task_file_name'], all_video[r['task_file_name']])
                                          for r in task_records])
vjepa_embeddings = self.embedding_model.get_embeddings(batch_names, batch_frames)
batch["vjepa_embedding"] = torch.tensor(np.array(vjepa_embeddings))
return TensorDict(batch, batch_size=chunk_size)
Single-Step Action Generation: _generate_one_step()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
@torch.no_grad()
def _generate_one_step(self, prompts: dict):
"""Generate one action chunk from VLA model."""
if self.config.embodied.embodied_type == "openvla-oft":
# OpenVLA-OFT: Action Flow Transformer variant
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
actions, response = self.model.generate_action_verl(
input_ids=idx,
pixel_values=pixel_values,
attention_mask=attention_mask,
do_sample=do_sample,
unnorm_key=self.config.embodied.unnorm_key,
temperature=temperature,
)
# response shape: (batch_size, action_chunks_len * action_token_len)
elif self.config.embodied.embodied_type == "openvla":
# Standard OpenVLA: Autoregressive token generation
output = self.model.generate(
input_ids=idx,
pixel_values=pixel_values,
attention_mask=attention_mask,
do_sample=do_sample,
max_new_tokens=response_length,
temperature=temperature,
)
# Decode action tokens to continuous actions
predicted_action_token_ids = output.sequences[:, prompt_length:]
discretized_actions = self.model.vocab_size - predicted_action_token_ids
normalized_actions = self.model.bin_centers[discretized_actions]
return {
'responses': response,
'input_ids': idx,
'attention_mask': attention_mask,
'pixel_values': pixel_values,
'action': actions,
}
**Key Output Fields**:
.. list-table::
:header-rows: 1
:widths: 25 30 45
* - Field
- Shape
- Description
* - ``responses``
- ``(B, traj_len, action_token_len)``
- Action tokens (e.g., 7-dim: xyz + quat + gripper)
* - ``complete``
- ``(B,)``
- Boolean: task success flag
* - ``finish_step``
- ``(B,)``
- Integer: episode termination step
* - ``vjepa_embedding``
- ``(B, embed_dim)``
- V-JEPA visual features for reward computation
2. Data Filtering (Dynamic Sampling)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**File**: ``siirl/user_interface/filter_interface/embodied.py``
**Function**: ``embodied_local_rank_sampling()``
This step filters out "too easy" or "too hard" prompts based on the success rate within each group.
.. code-block:: python
def embodied_local_rank_sampling(
config: SiiRLArguments,
batch: TensorDict,
**kwargs: Any,
) -> NodeOutput:
"""
Performs verification, metric collection, and optional filtering on a batch.
"""
# Step 1: Verify the entire batch to get scores and enrich it with an 'acc' tensor.
_, reward_metrics, format_metrics, reward_format_metrics = verify(batch)
# Step 2: Conditionally filter the batch based on accuracy and truncation
embodied_sampling = config.algorithm.embodied_sampling
if embodied_sampling.filter_accuracy or embodied_sampling.filter_truncated:
n_samples = config.actor_rollout_ref.rollout.n
processed_batch = _filter_batch(batch, n_samples, config)
else:
processed_batch = batch
        # (excerpt) sample_metrics is assembled from the reward/format metrics above
        return NodeOutput(batch=processed_batch, metrics=sample_metrics)
def _filter_batch(batch: TensorDict, n_samples: int, config: SiiRLArguments) -> TensorDict:
"""
Filters a batch based on accuracy and truncation criteria.
Filtering is performed at the prompt level.
"""
        num_prompts = len(batch) // n_samples
        device = batch["acc"].device  # 'acc' was added by verify() in the previous step
# --- 1. Accuracy Filtering ---
if config.algorithm.embodied_sampling.filter_accuracy:
# Reshape flat accuracy tensor into (num_prompts, n_samples)
acc_matrix = batch["acc"].reshape(num_prompts, n_samples)
# Calculate mean accuracy for each prompt
prompt_mean_acc = acc_matrix.mean(dim=-1)
# Create a boolean mask for prompts within the desired accuracy bounds
accuracy_lower_bound = config.algorithm.embodied_sampling.accuracy_lower_bound
accuracy_upper_bound = config.algorithm.embodied_sampling.accuracy_upper_bound
acc_mask = (prompt_mean_acc >= accuracy_lower_bound) & (prompt_mean_acc <= accuracy_upper_bound)
else:
acc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)
# --- 2. Truncation Filtering ---
if config.algorithm.embodied_sampling.filter_truncated:
finish_steps = batch["finish_step"].reshape(num_prompts, n_samples)
max_steps = config.actor_rollout_ref.embodied.env.max_steps
# A prompt is considered truncated if *any* of its samples reached max steps
has_truncated = (finish_steps >= max_steps).any(dim=-1)
trunc_mask = ~has_truncated
else:
trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)
# --- 3. Combine Masks and Apply Filter ---
combined_mask = acc_mask & trunc_mask
final_mask = combined_mask.repeat_interleave(n_samples)
filtered_batch = select_idxs(batch, final_mask)
return filtered_batch
**Why Filter?**
- **Too easy (acc > 0.9)**: All samples succeed → zero variance → zero advantage → no learning signal.
- **Too hard (acc < 0.1)**: All samples fail → similar issue.
- **Sweet spot (0.1 ≤ acc ≤ 0.9)**: Diverse outcomes → meaningful advantage estimates.
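
A quick numeric check of these cases, using the group normalization from the GRPO estimator described below (standalone sketch, not siiRL code):

.. code-block:: python

    import torch

    def grpo_advantages(scores: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """(score - group_mean) / (group_std + eps) for one prompt group."""
        return (scores - scores.mean()) / (scores.std() + eps)

    print(grpo_advantages(torch.tensor([1.0, 1.0, 1.0, 1.0])))  # all succeed -> all zeros
    print(grpo_advantages(torch.tensor([0.0, 0.0, 0.0, 0.0])))  # all fail    -> all zeros
    print(grpo_advantages(torch.tensor([1.0, 0.0, 1.0, 0.0])))  # mixed -> +/-0.87 signal
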
3. Reward Computation (VJEPA-based)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**File**: ``siirl/utils/reward_score/embodied.py``
**Function**: ``compute_embodied_reward()``
This is a key innovation of SRPO: using visual similarity to compute dense rewards for failed trajectories.
.. code-block:: python
    from typing import Any, Dict, List

    import numpy as np
    from scipy import special
    from scipy.spatial.distance import cdist
    from sklearn.cluster import DBSCAN
    from sklearn.preprocessing import StandardScaler
    from tensordict import TensorDict

    # (excerpt) module-local helpers _tensor_to_str_list / _extract_task_name omitted

    def compute_embodied_reward(
batch_data: TensorDict,
**kwargs: Any,
) -> List[Dict[str, Any]]:
"""
Computes rewards based on VJEPA embeddings and task completion status.
Reward Formula:
- Success: reward = 1.0
- Failure: reward = sigmoid(distance_to_success_cluster) ∈ [0, 0.6]
"""
# --- Step 1: Data Extraction and Pre-filtering ---
batch_size = batch_data["responses"].size(0)
completes = np.array(batch_data["complete"].tolist())
embeddings = batch_data["vjepa_embedding"].cpu().numpy()
task_file_names = _tensor_to_str_list(batch_data["task_file_name"])
# Pre-filtering: Identify invalid samples (all-zero embeddings)
zero_embedding_mask = np.all(embeddings == 0, axis=1)
valid_indices = np.where(~zero_embedding_mask)[0]
# --- Step 2: Initialize rewards ---
final_rewards = np.zeros(batch_size, dtype=float)
task_names = [_extract_task_name(name) for name in task_file_names]
# Group valid samples by task name
task_to_valid_indices = {}
for idx in valid_indices:
task_name = task_names[idx]
task_to_valid_indices.setdefault(task_name, []).append(idx)
# --- Step 3: Process each task group ---
for task_name, indices in task_to_valid_indices.items():
indices = np.array(indices)
task_completes = completes[indices]
success_indices = indices[task_completes]
fail_indices = indices[~task_completes]
# Success trajectories get reward = 1.0
final_rewards[success_indices] = 1.0
if len(success_indices) == 0 or len(fail_indices) == 0:
continue
# a. Cluster successful embeddings using DBSCAN
succ_embeddings = embeddings[success_indices]
scaler = StandardScaler()
scaled_succ_embeddings = scaler.fit_transform(succ_embeddings)
clustering = DBSCAN(eps=0.5, min_samples=2).fit(scaled_succ_embeddings)
cluster_centers = []
for label in set(clustering.labels_) - {-1}:
cluster_points = scaled_succ_embeddings[clustering.labels_ == label]
center = scaler.inverse_transform(cluster_points.mean(axis=0, keepdims=True)).flatten()
cluster_centers.append(center)
if not cluster_centers:
cluster_centers = [succ_embeddings.mean(axis=0)]
cluster_centers = np.array(cluster_centers)
# b. Compute distance from failed trajectories to nearest success cluster
fail_embeddings = embeddings[fail_indices]
distance_matrix = cdist(fail_embeddings, cluster_centers, "euclidean")
min_distances = distance_matrix.min(axis=1)
# c. Map distance to reward via sigmoid
max_dist, min_dist = min_distances.max(), min_distances.min()
dist_range = max_dist - min_dist
if dist_range < 1e-6:
normalized_dists = np.full_like(min_distances, 0.5)
else:
normalized_dists = (min_distances - min_dist) / dist_range
sigmoid_steepness = 10.0
sigmoid_offset = 0.5
sigmoid_inputs = sigmoid_steepness * (sigmoid_offset - normalized_dists)
reward_values = 0.6 * special.expit(sigmoid_inputs)
final_rewards[fail_indices] = reward_values
return [{"score": final_rewards[i]} for i in range(batch_size)]
**Reward Visualization**:
.. code-block:: text
Reward
^
1.0| ●●● (Success)
|
0.6| ───────────────────── (Max for failure)
| ╱
| ╱ Sigmoid curve
| ╱
0.0|───╱────────────────────▶ Distance to Success
Near Far
**Intuition**: Failed trajectories that are "visually similar" to successful ones (low distance) receive higher rewards, encouraging the policy to explore in promising directions.
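
The constants above can be checked numerically; the following standalone sketch mirrors the distance-to-reward mapping without calling siiRL:

.. code-block:: python

    import numpy as np
    from scipy import special

    normalized_dists = np.array([0.0, 0.5, 1.0])  # near, middle, far from success clusters
    rewards = 0.6 * special.expit(10.0 * (0.5 - normalized_dists))
    print(rewards.round(3))  # [0.596 0.3   0.004]
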
4. Advantage Calculation (GRPO)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**File**: ``siirl/dag_worker/core_algos.py``
**Function**: ``compute_grpo_outcome_advantage()``
GRPO computes advantages using group-relative normalization, eliminating the need for a Critic network.
.. code-block:: python
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, # (B, response_length)
response_mask: torch.Tensor, # (B, response_length)
index: np.ndarray, # (B,) - prompt index for grouping
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgorithmArguments] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
GRPO Advantage = (reward - group_mean) / group_std
This is the "Self-Referential" part: the baseline is computed from
the policy's own samples, not from a separate Value network.
"""
# Sum rewards across response tokens to get scalar reward per sample
scores = token_level_rewards.sum(dim=-1) # (B,)
# Group samples by prompt index
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
id2score[idx_key].append(scores[i])
# Compute group statistics
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
# Normalize: advantage = (score - mean) / std
for i in range(bsz):
idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
if norm_adv_by_std_in_grpo:
scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)
else:
scores[i] = scores[i] - id2mean[idx_key] # Dr.GRPO variant
# Broadcast to token level
scores = scores.unsqueeze(-1) * response_mask
return scores, scores # (advantages, returns)
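
A toy invocation of the estimator (a sketch assuming the import path named above; two samples of one prompt, three response tokens each):

.. code-block:: python

    import numpy as np
    import torch

    from siirl.dag_worker.core_algos import compute_grpo_outcome_advantage

    token_level_rewards = torch.tensor([[0.0, 0.0, 1.0],   # sample 0: reward 1.0
                                        [0.0, 0.0, 0.0]])  # sample 1: reward 0.0
    response_mask = torch.ones_like(token_level_rewards)
    index = np.array([0, 0])  # both samples belong to prompt group 0

    adv, ret = compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
    print(adv)  # ~ +0.71 on every token of sample 0, ~ -0.71 on sample 1
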
**Embodied-specific handling in compute_advantage()**:
.. code-block:: python
def compute_advantage(data: TensorDict, adv_estimator, ...):
if adv_estimator == AdvantageEstimator.GRPO:
if "finish_step" in data and data["responses"].ndim == 3:
# Embodied scenario: compute mask based on finish_step
responses = data["responses"]
batch_size = responses.size(0)
response_length = responses.size(1) * responses.size(2) # traj_len * action_token_len
action_token_len = responses.size(2)
finish_step = data['finish_step'] * action_token_len
steps = torch.arange(response_length, device=responses.device)
steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)
grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1)
else:
# NLP scenario: use attention_mask-based response_mask
grpo_calculation_mask = data["response_mask"]
advantages, returns = compute_grpo_outcome_advantage(
token_level_rewards=data["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data["uid"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
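
The embodied mask construction above can be illustrated with toy shapes (standalone sketch):

.. code-block:: python

    import torch

    # 2 trajectories, 4 action chunks each, 2 tokens per chunk -> 8 response tokens
    batch_size, traj_len, action_token_len = 2, 4, 2
    response_length = traj_len * action_token_len
    finish_step = torch.tensor([2, 4]) * action_token_len  # episodes ended at chunks 2 and 4

    steps = torch.arange(response_length).unsqueeze(0).expand(batch_size, -1)
    mask = (steps < finish_step.unsqueeze(1)).int()
    print(mask)
    # tensor([[1, 1, 1, 1, 0, 0, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 1, 1]])
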
5. Policy Update (PPO Loss)
~~~~~~~~~~~~~~~~~~~~~~~~~~~
**File**: ``siirl/engine/actor/embodied_actor.py``
**Class**: ``RobDataParallelPPOActor``
**Method**: ``update_policy()``
The actor update uses the standard PPO clipped objective with GRPO advantages.
.. code-block:: python
def update_policy(self, data: TensorDict):
self.actor_module.train()
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
temperature = data['temperature']
select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values',
'old_log_probs', 'advantages', "finish_step"]
batch = data.select(*select_keys)
dataloader = batch.split(self.config.ppo_mini_batch_size)
metrics = {}
for batch_idx, data in enumerate(dataloader):
mini_batch = data
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.actor_optimizer.zero_grad()
for test_idx, data in enumerate(micro_batches):
data = data.cuda()
responses = data['responses']
response_length = responses.size(1) * responses.size(2)
# Build response mask from finish_step
finish_step = data['finish_step'] * self.config.action_token_len
steps = torch.arange(response_length, device=responses.device)
steps_expanded = steps.unsqueeze(0).expand(responses.size(0), -1)
response_mask = steps_expanded < finish_step.unsqueeze(1)
old_log_prob = data['old_log_probs']
advantages = data['advantages']
                # (excerpt) input_ids, attention_mask, pixel_values are taken from `data`
                # Split trajectory into chunks for memory efficiency
                traj_len = responses.size(1)
                traj_split_num = int(traj_len / self.config.traj_mini_batch_size)
                chunk_size = traj_len // traj_split_num
                for i in range(0, traj_len, chunk_size):
# Forward pass to get current log probs
entropy, log_prob = self._forward_micro_batch_update(
input_ids=input_ids[i:i+chunk_size],
attention_mask=attention_mask[i:i+chunk_size],
pixel_values=pixel_values[i:i+chunk_size],
responses=responses[i:i+chunk_size],
temperature=temperature
)
                    # Compute PPO clipped loss
                    # (excerpt) the *_tmp tensors are the slices of old_log_prob,
                    # advantages, and response_mask for this trajectory chunk
pg_loss, pg_clipfrac, ppo_kl, _ = core_algos.compute_policy_loss_vanilla(
old_log_prob=old_log_prob_tmp,
log_prob=log_prob,
advantages=advantages_tmp,
response_mask=response_mask_tmp,
config=self.config
)
loss = pg_loss / self.gradient_accumulation
loss.backward()
grad_norm = self._optimizer_step()
return metrics
**PPO Loss Function** (from ``core_algos.py``):
.. code-block:: python
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
config: Optional[ActorArguments] = None,
...
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
L^CLIP(θ) = E[min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t)]
where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
"""
clip_ratio = config.clip_ratio
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
# Importance ratio
negative_approx_kl = log_prob - old_log_prob
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) # stability
ratio = torch.exp(negative_approx_kl)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
# Clipped objective
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
        # Dual-clip for negative advantages; clip_ratio_c is read from config (e.g. 3.0)
        pg_losses3 = -advantages * clip_ratio_c
        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
        pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
        # (excerpt) pg_clipfrac / pg_clipfrac_lower are the masked fractions of tokens on
        # which the clipped branches were active; their computation is omitted here.
        return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
Key Configuration Parameters
----------------------------
.. list-table::
:header-rows: 1
:widths: 35 30 35
* - Parameter
- Location
- Description
* - ``algorithm.adv_estimator``
- Training config
- Set to ``grpo`` for SRPO
* - ``actor_rollout_ref.rollout.n``
- Training script
- Group size (samples per prompt)
* - ``algorithm.embodied_sampling.filter_accuracy``
- Training script
- Enable accuracy-based filtering
* - ``algorithm.embodied_sampling.accuracy_lower_bound``
- Training script
- Min success rate (default: 0.1)
* - ``algorithm.embodied_sampling.accuracy_upper_bound``
- Training script
- Max success rate (default: 0.9)
* - ``algorithm.embodied_sampling.filter_truncated``
- Training script
- Filter truncated episodes
* - ``actor_rollout_ref.embodied.video_embedding_model_path``
- Training script
- Path to V-JEPA model
* - ``actor_rollout_ref.embodied.env.num_envs``
- Config
- Number of parallel environments
* - ``actor_rollout_ref.embodied.env.max_steps``
- Config
- Maximum steps per episode
* - ``actor_rollout_ref.embodied.action_chunks_len``
- Config
- Actions per VLA forward pass
Quick Reference: File Locations
-------------------------------
.. list-table::
:header-rows: 1
:widths: 30 70
* - Component
- File Path
* - Training Entry
- ``siirl/main_dag.py``
* - **Pipeline Definition**
- ``siirl/execution/dag/builtin_pipelines.py``
* - **Embodied Rollout**
- ``siirl/engine/rollout/embodied_rollout.py``
* - Environment Adapter
- ``siirl/environment/embodied/adapters/``
* - V-JEPA Embedding
- ``siirl/utils/embodied/video_emb.py``
* - **Data Filtering**
- ``siirl/user_interface/filter_interface/embodied.py``
* - **VJEPA Reward**
- ``siirl/utils/reward_score/embodied.py``
* - **GRPO Advantage**
- ``siirl/dag_worker/core_algos.py``
* - **VLA Actor**
- ``siirl/engine/actor/embodied_actor.py``
* - Example Scripts
- ``examples/embodied_srpo_trainer/run_openvla_oft_*.sh``
References
----------
1. SRPO Paper: `Self-Referential Policy Optimization for Vision-Language-Action Models `_
2. V-JEPA: `Video Joint Embedding Predictive Architecture 2 `_
================================================
FILE: docs/requirements-docs.txt
================================================
# markdown support
recommonmark
myst_parser
# markdown table support
sphinx-markdown-tables
# theme default rtd
# crate-docs-theme
sphinx-rtd-theme
# pin tokenizers version to avoid env_logger version req
tokenizers==0.21
================================================
FILE: docs/start/install.rst
================================================
Installation
============
siiRL provides three primary installation methods. We **strongly recommend** using the Docker image for the most reliable and hassle-free experience.
* :ref:`Method 1: Install from Docker Image (Recommended) `
* :ref:`Method 2: Install from PyPI (pip) `
* :ref:`Method 3: Install from Source (Custom Environment) `
Requirements
------------
- **Python**: Version >= 3.10
- **CUDA**: Version >= 12.1
Currently, siiRL supports the following configurations:
- **FSDP** for training.
- **SGLang** and **vLLM** for rollout generation.
.. _install-docker:
Method 1: Install from Docker Image (Recommended)
---------------------------------------------------

The stable image is ``siiai/siirl-base:vllm0.8.5.post1-sglang0.4.6.post5-cu124``. This image contains the latest versions of the inference and training frameworks and their dependencies.
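
A typical way to start a container from this image is sketched below; the mount paths and container name are placeholders, and your cluster may require different flags:

.. code:: bash

    # Pull the stable image
    docker pull siiai/siirl-base:vllm0.8.5.post1-sglang0.4.6.post5-cu124

    # Start an interactive container with all GPUs, host networking, and large shared memory
    docker run -it --gpus all --network host --shm-size 32g \
        -v /path/to/models:/models \
        -v /path/to/data:/data \
        --name siirl-dev \
        siiai/siirl-base:vllm0.8.5.post1-sglang0.4.6.post5-cu124 /bin/bash
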
.. _install-pip:
Method 2: Install from PyPI (pip)
----------------------------------

We provide prebuilt Python wheels for Linux. Install siiRL with the following commands:
.. code:: bash
# Install siiRL with vLLM
pip install siirl[vllm]
# Then, install required high-performance dependencies for siiRL
pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/
pip install flash-attn==2.7.3 --no-build-isolation
.. _install-source:
Method 3: Install from Source (Custom Environment)
----------------------------------------------------

We recommend using the Docker image for convenience. However, if your environment is not compatible with it, you can also install siiRL in a plain Python environment.
Install dependencies
::::::::::::::::::::
1. First, we recommend using conda to manage the environment:
.. code:: bash
conda create -n siirl python==3.10
conda activate siirl
2. Install python packages
.. note::
The following commands are an example for an environment with CUDA 12.4.
If you are using a different CUDA version, you must adjust the package versions and index URLs accordingly, especially for torch, flashinfer, and flash-attn.
.. code:: bash
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/
pip install flash-attn==2.7.3 --no-build-isolation
pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils
pip install 'ray[default]>=2.47.1'
pip install opentelemetry-exporter-prometheus==0.47b0
3. Then, execute the following commands to install vLLM (and, optionally, SGLang for the SGLang rollout backend):

.. code:: bash

    pip install vllm==0.8.5.post1
    # Optional SGLang backend; the version pin below is an assumption matching the Docker image tag
    pip install "sglang[all]==0.4.6.post5"
Install siirl
::::::::::::::
To install the latest version of siiRL, the best way is to clone and
install it from source. You can then modify the code to customize your
own post-training jobs.
.. code:: bash
git clone https://github.com/sii-research/siiRL.git
cd siirl
pip install -e .
================================================
FILE: docs/start/quickstart.rst
================================================
.. _quickstart:
=========================================================
Quickstart: GRPO training on GSM8K dataset
=========================================================
Post-train an LLM on the GSM8K dataset.
Introduction
------------
.. _hf_dataset_gsm8k: https://huggingface.co/datasets/gsm8k
In this example, we train an LLM to tackle the `GSM8k `_ task with function-based rewards.
Prerequisite:
- the latest version of ``siiRL`` and its dependencies installed following the installation guide. Using the docker image is recommended.
- a GPU with at least 24 GB HBM
Dataset Introduction
--------------------
GSM8k is a math problem dataset. The prompt is an elementary school
problem. The LLM model is asked to solve the math problem. Below is an example:
Prompt
Katy makes coffee using teaspoons of sugar and cups of water in the
ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups
of water, calculate the number of teaspoonfuls of sugar she used.
Solution
The total ratio representing the ingredients she used to make the
coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the
number of teaspoons she used is 7/20, she used 7/20\ *120 =
<<7/20*\ 120=42>>42 #### 42
Step 1: Prepare the dataset
----------------------------
We preprocess the dataset into Parquet format so that (1) it contains the fields necessary for computing RL rewards and (2) it is faster to read.
.. code-block:: bash
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
Step 2: Download a model for post-training
-------------------------------------------
In this example, we start with the ``Qwen2.5-0.5B-Instruct`` model.
.. code-block:: bash
python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
Step 3: Perform GRPO training with the instruct model
----------------------------------------------------------------------
**Reward Model/Function**
We use a pre-defined rule-based reward function. The model is required to produce its final
answer after four "#" characters (``####``), as in the reference solution. We extract the
final answer from both the solution and the model's output using regular-expression
matching, and assign a reward of 1.0 to a correct answer and 0.0 to an incorrect or
missing answer.
For more details, please refer to `siirl/utils/reward_score/gsm8k.py `_.
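
For intuition, here is a minimal standalone sketch of this extract-and-compare idea (illustrative only; the actual parsing in ``siirl/utils/reward_score/gsm8k.py`` may differ):

.. code-block:: python

    import re

    def extract_answer(text: str):
        """Return the number following '####', or None if absent."""
        match = re.search(r"####\s*(-?[\d,\.]+)", text)
        return match.group(1).replace(",", "") if match else None

    def gsm8k_reward(solution: str, model_output: str) -> float:
        truth, pred = extract_answer(solution), extract_answer(model_output)
        return 1.0 if pred is not None and pred == truth else 0.0

    print(gsm8k_reward("... 7 + 13 = 20 ... #### 42", "So the answer is #### 42"))  # 1.0
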
**Training Script**
Now let's run GRPO training with the dataset and model above. [1]_
Set ``data.train_files``, ``data.val_files``, ``actor_rollout_ref.model.path``, and ``critic.model.path`` based on your dataset and model names or paths.
.. code-block:: bash
python3 -m siirl.main_dag \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=128 \
data.max_prompt_length=2048 \
data.max_response_length=4096 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=False \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.grad_clip=0.5 \
actor_rollout_ref.actor.clip_ratio=0.2 \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.max_model_len=8192 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','tensorboard'] \
trainer.project_name=siirl_qwen2.5_0.5b_grpo \
trainer.experiment_name=siirl_qwen2.5_0.5b_grpo_toy \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=200 \
trainer.test_freq=10 \
trainer.total_epochs=30 \
trainer.resume_mode=auto \
trainer.max_actor_ckpt_to_keep=1 \
trainer.default_local_dir=ckpts/qwen2.5_0.5b/grpo/ \
trainer.val_before_train=True 2>&1 | tee verl_demo.log
You are expected to see the following logs, indicating training in progress. The key metric ``val/test_score/openai/gsm8k`` is computed every ``trainer.test_freq`` steps:
.. code-block:: bash
step:1 - training/epoch:1.000 - training/global_step:0.000 - training/rollout_probs_diff_max:0.373 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.438 - actor/grad_norm:0.221 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.003 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.002 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.002 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.576 - perf/cpu_memory_used_gb:125.440 - perf/delta_time/actor:72.260 - perf/delta_time/actor_log_prob:10.829 - perf/delta_time/advantage:0.039 - perf/delta_time/compute_core_metrics:0.020 - perf/delta_time/data_loading:1.030 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:128.358 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.036 - perf/delta_time/ref:28.170 - perf/delta_time/reference:28.172 - perf/delta_time/reset_data_buffer:0.038 - perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.255 - perf/delta_time/rollout:16.797 - perf/delta_time/step:129.426 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:34.832 - perf/max_mem_rsvd_gb:39.678 - perf/max_memory_allocated_gb:34.832 - perf/max_memory_reserved_gb:39.678 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.052 - perf/mfu/ref:0.021 - perf/mfu/rollout:0.079 - response_length/clip_ratio:0.610 - response_length/max:256.000 - response_length/mean:232.029 - response_length/min:76.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:189.000 - prompt_length/mean:104.727 - prompt_length/min:66.000 - perf/total_num_tokens:431047.000 - perf/time_per_step:129.426 - perf/throughput:3330.450
step:2 - training/epoch:1.000 - training/global_step:1.000 - training/rollout_probs_diff_max:0.326 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.432 - actor/grad_norm:0.210 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.004 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.004 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.004 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.589 - perf/cpu_memory_used_gb:125.617 - perf/delta_time/actor:72.457 - perf/delta_time/actor_log_prob:10.689 - perf/delta_time/advantage:0.040 - perf/delta_time/compute_core_metrics:0.001 - perf/delta_time/data_loading:0.005 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:123.794 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.001 - perf/delta_time/ref:24.271 - perf/delta_time/reference:24.273 - perf/delta_time/reset_data_buffer:0.005 - perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.286 - perf/delta_time/rollout:16.043 - perf/delta_time/step:123.805 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:36.362 - perf/max_mem_rsvd_gb:41.596 - perf/max_memory_allocated_gb:36.362 - perf/max_memory_reserved_gb:41.596 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.053 - perf/mfu/ref:0.024 - perf/mfu/rollout:0.082 - response_length/clip_ratio:0.595 - response_length/max:256.000 - response_length/mean:230.901 - response_length/min:20.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:215.000 - prompt_length/mean:105.098 - prompt_length/min:65.000 - perf/total_num_tokens:430078.000 - perf/time_per_step:123.805 - perf/throughput:3473.837
In addition, siiRL prints a formatted, easy-to-read summary of the core performance metrics on rank 0, providing a clear, separate view of the most important indicators.
.. code-block:: bash
========================= RANK(0): Core Performance Metrics (Step: 1) =========================
--- ⏱️ Overall Performance ---
Step Time : 129.426 s
Throughput (tokens/s) : 3330.45
Total Tokens in Step : 431047
--- 📈 Algorithm Metrics ---
Actor Entropy : 0.4380
Critic Rewards (Mean/Min/Max): 0.013 / 0.000 / 1.000
Critic Scores (Mean/Min/Max): 0.013 / 0.000 / 1.000
--- 🔥 Model Flops Utilization (MFU) ---
Mean MFU : N/A
Actor Training MFU : 0.023
Rollout MFU : 0.079
Reference Policy MFU : 0.021
Actor LogProb MFU : 0.052
--- 💾 Memory Usage ---
Max GPU Memory Allocated : 34.83 GB
Max GPU Memory Reserved : 39.68 GB
CPU Memory Used : 11.58 GB
--- 📏 Sequence Lengths ---
Prompt Length (Mean/Max) : 104.7 / 189
Response Length (Mean/Max) : 232.0 / 256
==================================================================================
Check out the ``Algorithm Baselines`` page for full training and validation logs for reference.

If you encounter out-of-memory issues on GPUs with less than 32 GB of HBM, enabling the following configs can help:
.. code-block:: bash
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
critic.ppo_micro_batch_size_per_gpu=1 \
For the full set of configs, please refer to :ref:`config-explain-page` for detailed explanation and performance tuning.
.. [1] More training script examples for FSDP backend are stored in `examples/ppo_trainer `_ directory.
================================================
FILE: docs/user_interface/filter_interface.rst
================================================
================
Filter Interface
================
Filter interface is used for dynamic sampling and data filtering in Pipelines.
**Location:** ``siirl/user_interface/filter_interface/``
Architecture Overview
---------------------
::
Filter Interface Architecture
==============================================================================
+------------------+ +-------------------+ +------------------+
| Previous Node | | Filter Node | | Next Node |
| (e.g. Reward) |---->| (COMPUTE type) |---->| (e.g. Advantage)|
+------------------+ +-------------------+ +------------------+
|
v
+---------------+
| Filter Logic |
+---------------+
| 1. Get batch |
| 2. Compute |
| mask |
| 3. Apply |
| filter |
| 4. Return |
| NodeOutput |
+---------------+
==============================================================================
Filter Execution Flow:
Input Batch Filter Function Output
+-----------+ +-------------+ +-----------+
| samples | | | | filtered |
| [0,1,2,3, | -------> | mask = | -------> | samples |
| 4,5,6,7] | | [T,T,F,T, | | [0,1,3,5] |
+-----------+ | F,T,F,F] | +-----------+
+-------------+
|
v
+-------------+
| Metrics: |
| kept_ratio |
| kept_groups |
+-------------+
Built-in Filters
----------------
DAPO Dynamic Sampling
~~~~~~~~~~~~~~~~~~~~~
**Location:** ``siirl/user_interface/filter_interface/dapo.py``
**Function:** ``dynamic_sampling()``
Filters zero-variance sample groups (all correct or all incorrect).
**Flow Diagram:**
::
Input: Batch with rewards grouped by uid (prompt)
+-----------------------------------------------------------+
| uid=0: [1.0, 1.0, 1.0, 1.0] -> std=0 -> FILTER OUT |
| uid=1: [1.0, 0.0, 1.0, 0.0] -> std>0 -> KEEP |
| uid=2: [0.0, 0.0, 0.0, 0.0] -> std=0 -> FILTER OUT |
| uid=3: [0.5, 0.8, 0.2, 0.9] -> std>0 -> KEEP |
+-----------------------------------------------------------+
Output: Only uid=1 and uid=3 samples remain
**How it works:**
1. Group samples by uid (prompt)
2. Calculate variance for each group
3. Filter groups with variance = 0
**Configuration:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.workflow_type=DAPO \
algorithm.filter_groups.enable=true \
algorithm.filter_groups.metric=seq_final_reward
**Usage in Pipeline:**
.. code-block:: python
pipeline.add_node(
"dynamic_sampling",
func="siirl.user_interface.filter_interface.dapo:dynamic_sampling",
deps=["function_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
)
**Returned Metrics:**
- ``dapo_sampling/kept_trajectories_ratio``
- ``dapo_sampling/kept_groups``
- ``dapo_sampling/total_groups``
Embodied AI Sampling
~~~~~~~~~~~~~~~~~~~~
**Location:** ``siirl/user_interface/filter_interface/embodied.py``
**Function:** ``embodied_local_rank_sampling()``
Filters Embodied AI data based on task completion and accuracy.
**Flow Diagram:**
::
Input: Embodied rollout batch
+-----------------------------------------------------------------------+
| |
| Step 1: verify() - Compute accuracy from 'complete' field |
| +-------------------------------------------------------------------+|
| | Sample 0: complete=True -> acc=1.0 ||
| | Sample 1: complete=False -> acc=0.0 ||
| | ... ||
| +-------------------------------------------------------------------+|
| |
| Step 2: _filter_batch() - Apply filters |
| +-------------------------------------------------------------------+|
| | Accuracy Filter (per prompt group): ||
| | prompt_mean_acc >= lower_bound (0.1) AND ||
| | prompt_mean_acc <= upper_bound (0.9) ||
| | ||
| | Truncation Filter: ||
| | finish_step < max_steps (not truncated) ||
| +-------------------------------------------------------------------+|
| |
+-----------------------------------------------------------------------+
Output: Filtered batch (only "learnable" samples)
**Features:**
- Task verification
- Accuracy-based filtering
- Truncated trajectory filtering (see the sketch below)
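A simplified sketch of the two filters, assuming plain Python lists for ``acc``, ``uids``, and ``finish_step`` (the real ``_filter_batch()`` operates on the rollout batch directly):

.. code-block:: python

    import torch

    def embodied_keep_mask(acc, uids, finish_step, max_steps,
                           lower_bound=0.1, upper_bound=0.9):
        """Keep samples whose prompt group has mean accuracy inside
        [lower_bound, upper_bound] and whose trajectory was not truncated."""
        # Mean accuracy per prompt group
        group_mean = {}
        for uid in set(uids):
            vals = [acc[i] for i, u in enumerate(uids) if u == uid]
            group_mean[uid] = sum(vals) / len(vals)
        acc_ok = torch.tensor([lower_bound <= group_mean[u] <= upper_bound for u in uids])
        not_truncated = torch.tensor([step < max_steps for step in finish_step])
        return acc_ok & not_truncated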
**Configuration:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.workflow_type=EMBODIED \
algorithm.embodied_sampling.filter_accuracy=true \
algorithm.embodied_sampling.filter_truncated=true \
algorithm.embodied_sampling.accuracy_lower_bound=0.0 \
algorithm.embodied_sampling.accuracy_upper_bound=1.0 \
actor_rollout_ref.embodied.env.max_steps=100
**Usage in Pipeline:**
.. code-block:: python
pipeline.add_node(
"dynamic_sampling",
func="siirl.user_interface.filter_interface.embodied:embodied_local_rank_sampling",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
)
Custom Filter
-------------
Basic Template
~~~~~~~~~~~~~~
.. code-block:: python
from siirl.params import SiiRLArguments
from siirl.dag_worker.data_structures import NodeOutput
from siirl.data_coordinator.sample import filter_tensordict
import torch
def my_custom_filter(
config: SiiRLArguments,
batch,
**kwargs
) -> NodeOutput:
"""Custom filter function"""
# Get data
rewards = batch.batch["rewards"]
        # Create filter mask; the threshold is read from config
        # (e.g. algorithm.filter_threshold, as in the next example)
        threshold = config.algorithm.filter_threshold
        mask = rewards > threshold  # Boolean tensor
# Apply filter
filtered_batch = filter_tensordict(batch, mask)
# Collect metrics
metrics = {
"filter/kept_ratio": mask.sum().item() / len(mask)
}
return NodeOutput(batch=filtered_batch, metrics=metrics)
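Here ``filter_tensordict`` applies the boolean mask along the batch dimension, so the returned batch keeps only the rows where the mask is ``True``, matching the execution flow diagram above.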
Example: Reward Threshold Filter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
def reward_threshold_filter(
config: SiiRLArguments,
batch,
**kwargs
) -> NodeOutput:
"""Filter samples below reward threshold"""
rewards = batch.batch["rewards"]
threshold = config.algorithm.filter_threshold
# Create mask
mask = rewards > threshold
# Apply filter
from siirl.data_coordinator.sample import filter_tensordict
filtered_batch = filter_tensordict(batch, mask)
# Metrics
metrics = {
"filter/kept_ratio": mask.sum().item() / len(mask),
"filter/threshold": threshold
}
return NodeOutput(batch=filtered_batch, metrics=metrics)
**Configuration:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.filter_threshold=0.5
**Usage in Pipeline:**
.. code-block:: python
pipeline.add_node(
"reward_filter",
func="my_module:reward_threshold_filter",
deps=["function_reward"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DYNAMIC_SAMPLING
)
Example: Group Variance Filter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
    from collections import defaultdict
    import torch
def group_variance_filter(
config: SiiRLArguments,
batch,
**kwargs
) -> NodeOutput:
"""Filter groups with low variance"""
rewards = batch.batch["rewards"]
uids = batch.batch["uid"]
# Group by uid
uid_to_rewards = defaultdict(list)
for i, uid in enumerate(uids):
uid_key = int(uid) if hasattr(uid, 'item') else uid
uid_to_rewards[uid_key].append(rewards[i].item())
# Calculate std for each group
min_std = config.algorithm.min_group_std
kept_uids = {
uid for uid, r in uid_to_rewards.items()
if torch.std(torch.tensor(r)).item() >= min_std
}
# Create mask
mask = torch.tensor([
(int(uids[i]) if hasattr(uids[i], 'item') else uids[i]) in kept_uids
for i in range(len(uids))
], dtype=torch.bool)
# Apply filter
from siirl.data_coordinator.sample import filter_tensordict
filtered_batch = filter_tensordict(batch, mask)
metrics = {
"filter/kept_groups": len(kept_uids),
"filter/total_groups": len(uid_to_rewards)
}
return NodeOutput(batch=filtered_batch, metrics=metrics)
================================================
FILE: docs/user_interface/metrics_interface.rst
================================================
=================
Metrics Interface
=================
Custom metrics allow you to track and aggregate any quantitative measures during training and validation. siiRL provides a distributed, Ray-based metrics system that automatically handles aggregation across all workers using various reduction operations (mean, max, min, sum).
Architecture Overview
---------------------
::
Distributed Metrics Architecture
==============================================================================
DAGWorker 0 DAGWorker 1 DAGWorker 2 DAGWorker N
+-----------+ +-----------+ +-----------+ +-----------+
| compute | | compute | | compute | | compute |
| metrics | | metrics | | metrics | | metrics |
+-----+-----+ +-----+-----+ +-----+-----+ +-----+-----+
| | | |
v v v v
+-----+-----+ +-----+-----+ +-----+-----+ +-----+-----+
| Metric | | Metric | | Metric | | Metric |
| Client | | Client | | Client | | Client |
+-----+-----+ +-----+-----+ +-----+-----+ +-----+-----+
| | | |
+------------------+------------------+------------------+
|
v
+-------------------+
| MetricWorker | (Ray Actor - Singleton)
| (Aggregator) |
+-------------------+
| - Collect metrics |
| - Wait for all |
| workers |
| - Aggregate: |
| mean/max/min/ |
| sum |
+--------+----------+
|
v
+-------------------+
| Final Metrics |
| (to Logger/WandB) |
+-------------------+
==============================================================================
Metrics Data Flow:
+-------------+ +----------------+ +----------------+ +--------+
| TensorDict | --> | compute_* | --> | MetricClient | --> | Metric |
| (batch) | | _metric() | | .submit_metric | | Worker |
+-------------+ +----------------+ +----------------+ +--------+
|
v
+-------------+
| Dict[str, |
| float] |
| {name: val} |
+-------------+
==============================================================================
**Key Files:**
- ``siirl/execution/metric_worker/metric_worker.py`` - Ray-based distributed metrics aggregation
- ``siirl/utils/metrics/metric_utils.py`` - Core metric computation functions
- ``siirl/execution/metric_worker/utils.py`` - Aggregation function utilities
Quick Start
-----------
Method 1: Extending Core Metrics Functions (Recommended)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
**Step 1:** Create your metric computation function in ``metric_utils.py``
.. code-block:: python
# Add to siirl/utils/metrics/metric_utils.py
def compute_custom_data_metrics(data: TensorDict) -> Dict[str, float]:
"""Custom metrics computed from batch data"""
metrics = {}
# Token-level accuracy
if "correct_tokens" in data and "attention_mask" in data:
correct = data["correct_tokens"].float()
mask = data["attention_mask"].float()
accuracy = (correct * mask).sum() / mask.sum()
metrics["custom/token_accuracy/mean"] = accuracy.item()
# Response quality score
if "responses" in data and "response_mask" in data:
response_quality = compute_response_quality_score(data)
metrics["custom/response_quality/mean"] = response_quality.mean().item()
metrics["custom/response_quality/max"] = response_quality.max().item()
metrics["custom/response_quality/min"] = response_quality.min().item()
return metrics
def compute_response_quality_score(data: TensorDict) -> torch.Tensor:
"""Helper function to compute response quality"""
responses = data["responses"]
response_mask = data["response_mask"]
# Example: vocabulary diversity score
unique_tokens_per_response = []
for i in range(responses.shape[0]):
response_tokens = responses[i][response_mask[i].bool()]
unique_count = len(torch.unique(response_tokens))
unique_tokens_per_response.append(unique_count)
return torch.tensor(unique_tokens_per_response, device=responses.device).float()
**Step 2:** Submit metrics using MetricClient
.. code-block:: python
# Usage in your training loop
from siirl.execution.metric_worker.metric_worker import MetricClient
# In your DAG worker or training script
custom_metrics = compute_custom_data_metrics(batch)
metric_client.submit_metric(custom_metrics, world_size)
Current Metrics System
----------------------
Built-in Metrics Reference
~~~~~~~~~~~~~~~~~~~~~~~~~~
The following tables list all built-in metrics provided by siiRL.
**Data Metrics** (from ``compute_data_metric`` in ``metric_utils.py``):
.. list-table:: Critic Metrics
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``critic/score/mean|max|min``
- Sequence-level scores from token-level scores
* - ``critic/rewards/mean|max|min``
- Sequence-level rewards from token-level rewards
* - ``critic/advantages/mean|max|min``
- Advantages (masked by response_mask)
* - ``critic/returns/mean|max|min``
- Returns (masked by response_mask)
* - ``critic/values/mean|max|min``
- Value function estimates (if available)
* - ``critic/vf_explained_var``
- Explained variance of value function
.. list-table:: Response Analysis Metrics
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``response/length/mean|max|min``
- Response token lengths
* - ``response/clip_ratio/mean``
- Proportion hitting max response length
* - ``response/correct_length/mean|max|min``
- Lengths for responses with reward > 0.5
* - ``response/wrong_length/mean|max|min``
- Lengths for responses with reward ≤ 0.5
.. list-table:: Prompt Analysis Metrics
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``prompt/length/mean|max|min``
- Prompt token lengths
* - ``prompt/clip_ratio/mean``
- Proportion hitting max prompt length
.. list-table:: System & Multi-turn Metrics
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``perf/process_cpu_mem_used_gb``
- CPU memory usage per process
* - ``num_turns/min|max|mean``
- Statistics for multi-turn conversations
**Timing Metrics** (from ``compute_timing_metrics``):
.. list-table::
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``timing_s/{stage}``
- Raw timing in seconds for each stage
* - ``timing_per_token_ms/{stage}``
- Per-token timing in milliseconds
Stages: ``gen``, ``ref``, ``values``, ``adv``, ``update_critic``, ``update_actor``
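For example (illustrative numbers), if the ``gen`` stage takes 12 s on a batch of 60,000 tokens, then ``timing_s/gen`` is 12 and, assuming the per-token value is the stage time divided by the token count, ``timing_per_token_ms/gen`` = 12 × 1000 / 60,000 = 0.2.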
**Throughput Metrics** (from ``compute_throughout_metrics``):
.. list-table::
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``perf/total_num_tokens``
- Total tokens processed
* - ``perf/time_per_step``
- Time per training step
* - ``perf/throughput``
- Tokens per second per GPU
**Validation Metrics** (from ``process_validation_metrics``):
.. list-table::
:header-rows: 1
:widths: 40 60
* - Metric Name
- Description
* - ``val-core/{data_source}/{var}/mean@N``
- Mean across N samples
* - ``val-core/{data_source}/{var}/best@N/mean|std``
- Bootstrap best-of-N statistics
* - ``val-core/{data_source}/{var}/worst@N/mean|std``
- Bootstrap worst-of-N statistics
* - ``val-core/{data_source}/{var}/maj@N/mean|std``
- Bootstrap majority voting statistics
* - ``val/test_score/{data_source}``
- Test score per data source
Custom Metrics Implementation
-----------------------------
Method 1: Custom Data Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Extend the data metrics computed from training batches:
.. code-block:: python
# Add to metric_utils.py
def compute_custom_training_metrics(data: TensorDict) -> Dict[str, float]:
"""Custom training-specific metrics"""
metrics = {}
# Policy entropy (exploration measure)
if "policy_logits" in data:
logits = data["policy_logits"]
probs = torch.softmax(logits, dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
response_mask = data.get("response_mask", torch.ones_like(entropy))
# Only compute entropy for response tokens
masked_entropy = entropy * response_mask.float()
valid_entropy = masked_entropy.sum() / response_mask.sum()
metrics["training/policy_entropy/mean"] = valid_entropy.item()
# Gradient norm tracking
if "grad_norm" in data:
metrics["training/grad_norm/mean"] = data["grad_norm"].item()
# Loss convergence tracking
if "loss_values" in data:
loss_values = data["loss_values"]
metrics["training/loss/mean"] = loss_values.mean().item()
metrics["training/loss/std"] = loss_values.std().item()
return metrics
# Usage in MetricClient.compute_local_data_metric
def compute_local_data_metric(self, data: TensorDict, world_size: int):
# Standard metrics
standard_metrics = compute_data_metric(data)
# Add custom metrics
custom_metrics = compute_custom_training_metrics(data)
# Combine and submit
all_metrics = {**standard_metrics, **custom_metrics}
self.submit_metric(all_metrics, world_size)
Method 2: Custom Validation Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Add custom validation metrics with bootstrap sampling:
.. code-block:: python
# Add to metric_utils.py
def compute_custom_validation_metrics(
data_sources: list[str],
sample_inputs: list[str],
infos_dict: dict[str, list],
sample_turns: list[int]
) -> dict[str, float]:
"""Custom validation metrics with bootstrap analysis"""
# Extract custom fields from infos_dict
custom_metrics = {}
if "custom_score" in infos_dict:
# Group by data source
source_scores = defaultdict(list)
for i, source in enumerate(data_sources):
source_scores[source].append(infos_dict["custom_score"][i])
# Compute statistics per source
for source, scores in source_scores.items():
if len(scores) > 0:
custom_metrics[f"val/custom_score/{source}/mean"] = np.mean(scores)
custom_metrics[f"val/custom_score/{source}/std"] = np.std(scores)
# Bootstrap sampling for confidence intervals
if len(scores) > 1:
bootstrap_results = bootstrap_metric(
data=scores,
subset_size=min(5, len(scores)),
reduce_fns=[np.mean, np.max, np.min],
n_bootstrap=1000
)
custom_metrics[f"val/custom_score/{source}/bootstrap_mean"] = bootstrap_results[0][0]
custom_metrics[f"val/custom_score/{source}/bootstrap_mean_std"] = bootstrap_results[0][1]
# Conversation quality for multi-turn
if "conversation_quality" in infos_dict and len(sample_turns) > 0:
quality_by_turns = defaultdict(list)
for i, turns in enumerate(sample_turns):
if i < len(infos_dict["conversation_quality"]):
quality_by_turns[turns].append(infos_dict["conversation_quality"][i])
for turn_count, qualities in quality_by_turns.items():
if len(qualities) > 0:
custom_metrics[f"val/conversation_quality/turns_{turn_count}/mean"] = np.mean(qualities)
return custom_metrics
# Usage in MetricClient.process_local_validation_metrics
def process_local_validation_metrics(self, data_sources, sample_inputs, infos_dict, sample_turns, world_size):
# Standard validation metrics
standard_metrics = process_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)
# Add custom validation metrics
custom_metrics = compute_custom_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)
# Combine and submit
all_metrics = {**standard_metrics, **custom_metrics}
self.submit_metric(all_metrics, world_size)
Method 3: Custom Aggregation Logic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Create custom aggregation functions for specialized reduction operations:
.. code-block:: python
# Add to execution/metric_worker/utils.py
def MedianMetric(metrics: List[Metric]):
"""Custom median aggregation"""
values = [v for metric in metrics
for v in (metric.value if isinstance(metric.value, list) else [metric.value])]
return float(torch.median(torch.tensor(values)).item())
def PercentileMetric(percentile: float):
"""Custom percentile aggregation factory"""
def _percentile_metric(metrics: List[Metric]):
values = [v for metric in metrics
for v in (metric.value if isinstance(metric.value, list) else [metric.value])]
return float(torch.quantile(torch.tensor(values), percentile / 100.0).item())
return _percentile_metric
# Update MetricFunc to handle custom aggregations
def MetricFunc(name: str):
if "median" in name:
return MedianMetric
elif "p95" in name:
return PercentileMetric(95)
elif "p99" in name:
return PercentileMetric(99)
elif "min" in name:
return MinMetric
elif "max" in name:
return MaxMetric
elif "sum" in name or "total" in name:
return SumMetric
else:
return MeanMetric
# Usage: name your metrics to trigger specific aggregations
metrics = {
"custom/latency/median": latency_values, # Will use MedianMetric
"custom/score/p95": score_values, # Will use 95th percentile
}
Method 4: Complex Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For more sophisticated metrics requiring multiple computation steps:
.. code-block:: python
# Add to metric_utils.py
def compute_advanced_metrics(data: TensorDict) -> Dict[str, float]:
"""Advanced metrics requiring complex computation"""
metrics = {}
# Sequence coherence analysis
if "responses" in data and "attention_mask" in data:
coherence_scores = compute_sequence_coherence(data)
metrics.update({
"analysis/coherence/mean": coherence_scores.mean().item(),
"analysis/coherence/std": coherence_scores.std().item(),
"analysis/coherence/median": coherence_scores.median().item(),
})
# Token transition analysis
if "responses" in data:
transition_metrics = analyze_token_transitions(data)
metrics.update(transition_metrics)
# Reward distribution analysis
if "token_level_rewards" in data:
reward_dist_metrics = analyze_reward_distribution(data)
metrics.update(reward_dist_metrics)
return metrics
def compute_sequence_coherence(data: TensorDict) -> torch.Tensor:
"""Compute coherence score for each sequence"""
responses = data["responses"]
attention_mask = data["attention_mask"]
batch_size = responses.shape[0]
coherence_scores = []
for i in range(batch_size):
# Extract valid tokens for this sequence
valid_length = attention_mask[i].sum().item()
sequence = responses[i][:valid_length]
# Compute local coherence (e.g., token transition smoothness)
if len(sequence) > 1:
# Simplified coherence: variance in token values
coherence = 1.0 / (1.0 + torch.var(sequence.float()).item())
else:
coherence = 1.0
coherence_scores.append(coherence)
return torch.tensor(coherence_scores, device=responses.device)
def analyze_token_transitions(data: TensorDict) -> Dict[str, float]:
"""Analyze patterns in token transitions"""
responses = data["responses"]
response_mask = data.get("response_mask", torch.ones_like(responses))
# Count unique transitions
unique_transitions = set()
total_transitions = 0
for i in range(responses.shape[0]):
response_tokens = responses[i][response_mask[i].bool()]
if len(response_tokens) > 1:
for j in range(len(response_tokens) - 1):
transition = (response_tokens[j].item(), response_tokens[j+1].item())
unique_transitions.add(transition)
total_transitions += 1
diversity_ratio = len(unique_transitions) / max(total_transitions, 1)
return {
"analysis/transition_diversity/mean": diversity_ratio,
"analysis/unique_transitions/total": len(unique_transitions),
"analysis/total_transitions/total": total_transitions,
}
Integration with Training Workflow
----------------------------------
MetricClient Usage Pattern
~~~~~~~~~~~~~~~~~~~~~~~~~~~
The ``MetricClient`` provides the main interface for submitting metrics:
.. code-block:: python
from siirl.execution.metric_worker.metric_worker import MetricClient, MetricWorker
# Initialize metric worker and client
metric_worker = MetricWorker.remote()
await metric_worker.start.remote()
metric_client = MetricClient(metric_worker)
# During training loop
for step, batch in enumerate(dataloader):
# ... training logic ...
# Submit standard metrics
metric_client.compute_local_data_metric(batch, world_size)
# Submit custom metrics
custom_metrics = compute_advanced_metrics(batch)
metric_client.submit_metric(custom_metrics, world_size)
# Submit timing metrics
timing_data = {"step": step_time, "forward": forward_time}
metric_client.compute_local_timing_metrics(batch, timing_data, world_size)
# Wait for metrics to be processed
metric_client.wait_submit()
# Get final aggregated results
final_metrics = metric_client.wait_final_res()
Ray-based Distributed Aggregation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The system uses Ray actors for distributed metrics processing:
**MetricWorker Actor:**
- Runs asynchronously to collect metrics from all workers
- Aggregates metrics when all processes have submitted values
- Supports different aggregation functions (mean, max, min, sum)
- Automatically handles timing metric renaming (``timing_s/`` → ``perf/delta_time/``)
**Aggregation Logic:**
- Metrics are collected in a queue until all workers (``world_size``) submit
- Each metric triggers computation when the expected number of submissions is reached
- Final results are stored and returned when requested (a simplified sketch follows)
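The collect-then-reduce behavior can be sketched in a few lines. This is a simplified illustration of the logic described above, not the actual ``MetricWorker`` implementation:

.. code-block:: python

    from collections import defaultdict

    class TinyAggregator:
        """Queue values per metric name; reduce once all workers have submitted."""

        def __init__(self, world_size, reduce_fn=lambda vs: sum(vs) / len(vs)):
            self.world_size = world_size
            self.reduce_fn = reduce_fn
            self.pending = defaultdict(list)
            self.results = {}

        def submit(self, name, value):
            self.pending[name].append(value)
            if len(self.pending[name]) == self.world_size:  # all workers reported
                self.results[name] = self.reduce_fn(self.pending.pop(name))

    agg = TinyAggregator(world_size=2)
    agg.submit("loss/mean", 0.5)  # worker 0
    agg.submit("loss/mean", 1.5)  # worker 1 -> triggers aggregation
    assert agg.results["loss/mean"] == 1.0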
Special Metric Configurations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Some metrics require special aggregation logic:
.. code-block:: python
# In metric_worker.py
Special_Metric = {
"graph_output_handling": MaxMetric, # Only rollout_tp 0 contributes
}
Custom metrics can be added to this dictionary for specialized handling.
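For example, to force max-aggregation for a metric that only one rank contributes meaningfully (the metric name below is hypothetical):

.. code-block:: python

    # In metric_worker.py; "custom/peak_gpu_mem_gb" is a hypothetical metric name
    Special_Metric["custom/peak_gpu_mem_gb"] = MaxMetric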
Advanced Examples
-----------------
Example 1: Model Performance Analysis
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
def compute_model_performance_metrics(data: TensorDict, model_outputs: dict) -> Dict[str, float]:
"""Comprehensive model performance analysis"""
metrics = {}
# Attention pattern analysis
if "attention_weights" in model_outputs:
attention_weights = model_outputs["attention_weights"]
# Attention concentration (how focused is attention)
attention_entropy = -torch.sum(
attention_weights * torch.log(attention_weights + 1e-9), dim=-1
)
metrics["model/attention_entropy/mean"] = attention_entropy.mean().item()
# Attention on different token types
if "attention_mask" in data:
prompt_attention = attention_weights[:, :, :-data["responses"].shape[-1]]
response_attention = attention_weights[:, :, -data["responses"].shape[-1]:]
metrics["model/prompt_attention_ratio/mean"] = (
prompt_attention.sum() / attention_weights.sum()
).item()
# Hidden state analysis
if "hidden_states" in model_outputs:
hidden_states = model_outputs["hidden_states"]
# Representation diversity
layer_norms = torch.norm(hidden_states, dim=-1)
metrics["model/hidden_norm/mean"] = layer_norms.mean().item()
metrics["model/hidden_norm/std"] = layer_norms.std().item()
return metrics
Example 2: Conversation Quality Assessment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
def compute_conversation_quality_metrics(data: TensorDict) -> Dict[str, float]:
"""Multi-dimensional conversation quality assessment"""
metrics = {}
if "responses" not in data or "prompts" not in data:
return metrics
responses = data["responses"]
prompts = data["prompts"]
response_mask = data.get("response_mask", torch.ones_like(responses))
batch_size = responses.shape[0]
quality_scores = []
for i in range(batch_size):
# Extract actual tokens (remove padding)
response_tokens = responses[i][response_mask[i].bool()]
prompt_tokens = prompts[i]
# Length appropriateness (not too short, not too long)
response_length = len(response_tokens)
length_score = compute_length_appropriateness(response_length)
# Vocabulary richness
unique_tokens = len(torch.unique(response_tokens))
            vocab_score = min(unique_tokens / response_length, 1.0) if response_length > 0 else 0.0
# Repetition penalty
repetition_score = compute_repetition_score(response_tokens)
# Overall quality
quality = 0.3 * length_score + 0.3 * vocab_score + 0.4 * repetition_score
quality_scores.append(quality)
quality_tensor = torch.tensor(quality_scores, device=responses.device)
return {
"conversation/quality/mean": quality_tensor.mean().item(),
"conversation/quality/std": quality_tensor.std().item(),
"conversation/quality/min": quality_tensor.min().item(),
"conversation/quality/max": quality_tensor.max().item(),
}
def compute_length_appropriateness(length: int, target_length: int = 50) -> float:
"""Compute how appropriate the response length is"""
if length == 0:
return 0.0
ratio = length / target_length
        if ratio <= 1.0:
            return ratio  # Below target: longer responses score higher
        else:
            return 1.0 / ratio  # Penalize overly long responses
def compute_repetition_score(tokens: torch.Tensor) -> float:
"""Compute score based on repetition patterns"""
if len(tokens) <= 1:
return 1.0
# Count repeated bigrams
bigrams = set()
repeated_bigrams = 0
for i in range(len(tokens) - 1):
bigram = (tokens[i].item(), tokens[i+1].item())
if bigram in bigrams:
repeated_bigrams += 1
else:
bigrams.add(bigram)
# Higher repetition = lower score
repetition_ratio = repeated_bigrams / (len(tokens) - 1)
return 1.0 - repetition_ratio
Configuration and Best Practices
---------------------------------
Metric Naming Conventions
~~~~~~~~~~~~~~~~~~~~~~~~~~
Follow these conventions for consistent metric organization:
.. code-block:: text
# Training metrics
training/{category}/{metric_name}/{aggregation}
# Validation metrics
val/{category}/{data_source}/{metric_name}
val-core/{data_source}/{variable}/{metric_name}
val-aux/{category}/{metric_name}
# Performance metrics
perf/{metric_name}
# Analysis metrics
analysis/{category}/{metric_name}/{aggregation}
# Model introspection
model/{component}/{metric_name}/{aggregation}
Aggregation Selection
~~~~~~~~~~~~~~~~~~~~~
Choose aggregation methods based on metric semantics:
- **mean**: Default for most metrics (accuracy, loss, etc.)
- **max**: For peak values (max memory, worst-case latency)
- **min**: For best-case scenarios (min loss, fastest response)
- **sum/total**: For cumulative values (total tokens, total time)
- **median**: For robust central tendency (when outliers matter)
- **p95/p99**: For percentile-based SLA metrics
Error Handling
~~~~~~~~~~~~~~
Always implement robust error handling:
.. code-block:: python
def compute_safe_custom_metrics(data: TensorDict) -> Dict[str, float]:
"""Example of safe metric computation"""
metrics = {}
try:
# Check data availability
if "required_field" not in data:
return metrics
# Handle empty tensors
values = data["required_field"]
if values.numel() == 0:
return metrics
# Compute metrics with numerical stability
mean_val = torch.mean(values.float())
if torch.isfinite(mean_val):
metrics["custom/metric/mean"] = mean_val.item()
except Exception as e:
# Log error but don't crash training
print(f"Error computing custom metrics: {e}")
return {}
return metrics
Performance Considerations
~~~~~~~~~~~~~~~~~~~~~~~~~~
- **Batch Processing**: Compute metrics on entire batches, not individual samples
- **Device Placement**: Keep tensors on the same device as input data
- **Memory Management**: Avoid accumulating large tensors across steps
- **Async Processing**: Use Ray actors for non-blocking metrics aggregation
- **Selective Computation**: Only compute expensive metrics when needed
Debugging Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
    import os
    import numpy as np
def debug_custom_metrics(data: TensorDict, metrics: Dict[str, float]):
"""Debug utility for custom metrics"""
if os.environ.get("DEBUG_METRICS", "0") == "1":
print(f"Data keys: {list(data.keys())}")
print(f"Data shapes: {[(k, v.shape if hasattr(v, 'shape') else type(v)) for k, v in data.items()]}")
print(f"Computed metrics: {metrics}")
# Check for common issues
for name, value in metrics.items():
if not isinstance(value, (int, float)):
print(f"WARNING: Metric {name} has invalid type {type(value)}")
elif not np.isfinite(value):
print(f"WARNING: Metric {name} is not finite: {value}")
File Structure Summary
----------------------
.. code-block:: text
siirl/execution/metric_worker/
├── metric_worker.py # Ray actor for distributed aggregation
│ ├── MetricWorker # Ray remote actor class
│ └── MetricClient # Client interface
└── utils.py # Aggregation functions
├── Metric # Dataclass for metric values
├── MetricFunc # Function selection logic
├── MeanMetric # Mean aggregation
├── MaxMetric # Maximum aggregation
├── MinMetric # Minimum aggregation
└── SumMetric # Sum aggregation
siirl/utils/metrics/
└── metric_utils.py # Core metric computation
├── compute_data_metric # Standard training metrics
├── compute_timing_metrics # Timing analysis
├── compute_throughout_metrics # Throughput analysis
├── process_validation_metrics # Validation with bootstrap
├── bootstrap_metric # Bootstrap sampling utility
└── aggregate_validation_metrics # Parallel validation processing
This architecture provides a scalable, flexible foundation for comprehensive metrics collection in distributed training environments.
================================================
FILE: docs/user_interface/pipeline_interface.rst
================================================
============
Pipeline API
============
Pipeline is a declarative Python API for defining training workflows. Each Pipeline consists of Nodes connected through dependencies to form a DAG.
Architecture Overview
---------------------
::
Pipeline Architecture
==============================================================================
+------------------+ +------------------+
| Pipeline | .build() | TaskGraph |
| (Builder) | ------------------> | (DAG) |
+------------------+ +------------------+
| - pipeline_id | | - graph_id |
| - description | | - nodes: Dict |
| - _nodes: Dict | | - adj: Dict |
+------------------+ | - rev_adj: Dict |
+------------------+
|
| executed by
v
+------------------+
| DAGWorker |
| (per GPU) |
+------------------+
==============================================================================
Built-in Pipelines Comparison:
+----------+------------------------------------------------------------------+
| Pipeline | Nodes Flow |
+----------+------------------------------------------------------------------+
| GRPO | rollout -> reward -> advantage -> old_log -> ref_log -> train |
+----------+------------------------------------------------------------------+
| PPO | rollout -> reward -> value -> advantage -> old_log -> ref_log |
| | -> train_actor -> train_critic |
+----------+------------------------------------------------------------------+
| DAPO | rollout -> reward -> dynamic_sampling -> advantage -> old_log |
| | -> ref_log -> train |
+----------+------------------------------------------------------------------+
| Embodied | rollout -> embodied_sampling -> reward -> advantage -> old_log |
| SRPO | -> ref_log -> train |
+----------+------------------------------------------------------------------+
Basic Usage
-----------
Creating a Pipeline
~~~~~~~~~~~~~~~~~~~
.. code-block:: python
from siirl.execution.dag.pipeline import Pipeline
from siirl.execution.dag.node import NodeType, NodeRole
pipeline = Pipeline("my_pipeline", "Description")
# Add nodes (supports chaining)
pipeline.add_node(
"node_id",
func="module:function", # or "module:Class.method"
deps=["dependency_node_ids"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DEFAULT
).add_node(
"next_node",
func="module:another_function",
deps=["node_id"],
node_type=NodeType.MODEL_TRAIN,
node_role=NodeRole.ACTOR
)
# Build TaskGraph
task_graph = pipeline.build()
Node Parameters
~~~~~~~~~~~~~~~
- ``node_id``: Unique identifier
- ``func``: Function path (``"module:function"`` or ``"module:Class.method"``)
- ``deps``: List of dependency node IDs
- ``node_type``: MODEL_INFERENCE / MODEL_TRAIN / COMPUTE / DATA_LOAD
- ``node_role``: ROLLOUT / ACTOR / CRITIC / REFERENCE / REWARD / ADVANTAGE / DYNAMIC_SAMPLING / DEFAULT
- ``only_forward_compute``: Forward only (default False); see the sketch below
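For instance, ``only_forward_compute`` is typically set on nodes that run a model forward pass without updating weights, such as a reference log-prob node. A minimal sketch (the ``func`` path is hypothetical and shown only to illustrate the parameter):

.. code-block:: python

    pipeline.add_node(
        "ref_log_prob",
        func="my_module:compute_ref_log_prob",  # hypothetical function path
        deps=["rollout_actor"],
        node_type=NodeType.MODEL_INFERENCE,
        node_role=NodeRole.REFERENCE,
        only_forward_compute=True,  # forward pass only, no gradient update
    )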
Built-in Pipelines
------------------
siiRL provides 4 built-in pipelines in ``siirl/execution/dag/builtin_pipelines.py``:
GRPO Pipeline
~~~~~~~~~~~~~
**Workflow:** rollout → reward → advantage → old_log_prob → ref_log_prob → train_actor
**Usage:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.adv_estimator=grpo
PPO Pipeline
~~~~~~~~~~~~
**Workflow:** rollout → reward → critic_value → advantage → old_log_prob → ref_log_prob → train_actor → train_critic
**Key Difference:** Adds value function and critic training
**Usage:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.adv_estimator=gae \
critic.enable=true
DAPO Pipeline
~~~~~~~~~~~~~
**Workflow:** rollout → reward → dynamic_sampling → advantage → old_log_prob → ref_log_prob → train_actor
**Key Feature:** Filters zero-variance sample groups
**Usage:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.workflow_type=DAPO \
algorithm.filter_groups.enable=true
Embodied GRPO Pipeline
~~~~~~~~~~~~~~~~~~~~~~~
**Workflow:** rollout → embodied_sampling → reward → advantage → old_log_prob → ref_log_prob → train_actor
**Key Feature:** Embodied AI specific filtering
**Usage:**
.. code-block:: bash
python -m siirl.main_dag \
algorithm.workflow_type=EMBODIED
Custom Pipeline Definition
---------------------------
Define Custom Pipeline
~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
from siirl.execution.dag.pipeline import Pipeline
from siirl.execution.dag.task_graph import TaskGraph
from siirl.execution.dag.node import NodeType, NodeRole
def my_custom_pipeline() -> TaskGraph:
pipeline = Pipeline("my_pipeline", "My workflow")
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[],
node_type=NodeType.MODEL_INFERENCE,
node_role=NodeRole.ROLLOUT
).add_node(
"my_custom_node",
func="my_module:my_function",
deps=["rollout_actor"],
node_type=NodeType.COMPUTE,
node_role=NodeRole.DEFAULT
)
return pipeline.build()
Custom Node Function
~~~~~~~~~~~~~~~~~~~~
Node functions must follow this signature:
.. code-block:: python
from siirl.dag_worker.data_structures import NodeOutput
def my_function(batch, config=None, **kwargs) -> NodeOutput:
"""
Args:
batch: Input data (TensorDict)
config: Global configuration
**kwargs: Additional arguments
Returns:
NodeOutput(batch=processed_batch, metrics={})
"""
# Process batch
processed_batch = process(batch)
# Collect metrics
metrics = {"metric_name": value}
return NodeOutput(batch=processed_batch, metrics=metrics)
Use Custom Pipeline
~~~~~~~~~~~~~~~~~~~
**Command Line:**
.. code-block:: bash
python -m siirl.main_dag \
dag.custom_pipeline_fn="my_module:my_custom_pipeline"
================================================
FILE: docs/user_interface/reward_interface.rst
================================================
================
Reward Interface
================
Custom reward functions allow you to score model-generated responses. Simply write a Python function and specify its path in configuration.
**Official Example:** ``siirl/user_interface/rewards_interface/custom_gsm8k_reward.py``
Architecture Overview
---------------------
::
Reward Computation Flow
==============================================================================
+------------------+ +-------------------+ +------------------+
| Rollout Node | | Reward Node | | Advantage Node |
| (Generation) |---->| (Scoring) |---->| (Normalization) |
+------------------+ +-------------------+ +------------------+
|
v
+---------------+
| RewardManager |
+---------------+
|
+------------------------+------------------------+
| | |
v v v
+---------------+ +---------------+ +---------------+
| Naive Reward | | Batch Reward | | Custom Reward |
| (Rule-based) | | (Model-based) | | (User-defined)|
+---------------+ +---------------+ +---------------+
|
v
+---------------+
| compute_score |
| (data_source, |
| solution_str,|
| ground_truth,|
| extra_info) |
+-------+-------+
|
v
+---------------+
| Returns float |
| score [0, 1] |
+---------------+
==============================================================================
Custom Reward Function Integration:
Configuration Runtime
+---------------------------+ +---------------------------+
| custom_reward_function: | | RewardManager loads |
| path: /path/to/file.py | -----> | compute_score function |
| name: compute_score | | and calls it per sample |
+---------------------------+ +---------------------------+
Quick Start
-----------
Step 1: Write Reward Function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Create a Python file with ``compute_score`` function:
.. code-block:: python
# my_reward.py
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""
Custom reward function
Args:
data_source (str): Dataset source identifier (e.g., "openai/gsm8k")
solution_str (str): Model generated text
ground_truth (str): Correct answer
extra_info (dict): Additional information (optional)
Returns:
float: Score (typically 0-1)
"""
# Your scoring logic
if solution_str == ground_truth:
return 1.0
else:
return 0.0
Step 2: Configuration
~~~~~~~~~~~~~~~~~~~~~
**Command Line:**
.. code-block:: bash
python -m siirl.main_dag \
custom_reward_function.path=/path/to/my_reward.py \
custom_reward_function.name=compute_score
Official Example: GSM8K
-----------------------
**File:** ``siirl/user_interface/rewards_interface/custom_gsm8k_reward.py``
.. code-block:: python
import re
def extract_solution(solution_str, method="strict"):
"""Extract answer from solution"""
if method == "strict":
# Requires #### format
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
if solution is None:
return None
final_answer = solution.group(1).replace(",", "")
return final_answer
elif method == "flexible":
# Extract last number
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
if len(answer) == 0:
return None
for final_answer in reversed(answer):
if final_answer not in ["", "."]:
return final_answer
return None
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""
GSM8K scoring function
Checks format and compares answer
"""
method = "strict"
format_score = 0.0
score = 1.0
answer = extract_solution(solution_str, method=method)
if answer is None:
            return 0.0  # Format error
elif answer == ground_truth:
return score # Correct answer
else:
return format_score # Correct format but wrong answer
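To make the strict-mode behavior concrete, here is how the two functions behave on illustrative strings (not taken from the dataset):

.. code-block:: python

    assert extract_solution("So the total is #### 72") == "72"
    assert extract_solution("So the total is 72") is None  # no '#### ' marker
    assert compute_score("openai/gsm8k", "... #### 72", "72", None) == 1.0  # correct
    assert compute_score("openai/gsm8k", "... #### 71", "72", None) == 0.0  # wrong answer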
**Usage:**
.. code-block:: bash
python -m siirl.main_dag \
custom_reward_function.path=siirl/user_interface/rewards_interface/custom_gsm8k_reward.py \
custom_reward_function.name=compute_score \
data.train_files=/path/to/gsm8k.parquet
Custom Examples
---------------
Example 1: Keyword Matching
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""Keyword-based reward"""
score = 0.0
# Check keywords
keywords = ["because", "therefore", "thus"]
for keyword in keywords:
if keyword in solution_str.lower():
score += 0.3
# Length check
words = len(solution_str.split())
if 50 <= words <= 200:
score += 0.4
return min(score, 1.0)
Example 2: Regex Validation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
import re
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""Regex-based format validation"""
        # Extract the numeric answer; the pattern matches Chinese
        # "答案是/为: <number>" (i.e. "the answer is: <number>")
        match = re.search(r"答案[是为][::]\s*(\d+)", solution_str)
if match is None:
return 0.0 # Incorrect format
answer = match.group(1)
if answer == ground_truth:
return 1.0 # Correct
else:
return 0.1 # Correct format but wrong answer
Example 3: Multi-stage Scoring
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
import re
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""Multi-stage scoring: format + reasoning + correctness"""
score = 0.0
# Stage 1: Format check (0.2 points)
if "####" in solution_str:
score += 0.2
# Stage 2: Reasoning steps (0.3 points)
steps = solution_str.count('\n')
if steps >= 3:
score += 0.3
# Stage 3: Answer correctness (0.5 points)
answer_match = re.search(r"#### ([\-0-9\.]+)", solution_str)
if answer_match:
answer = answer_match.group(1)
if answer == ground_truth:
score += 0.5
return score
Example 4: Multiple Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""Route to different scoring functions based on data_source"""
if data_source == "gsm8k":
return score_gsm8k(solution_str, ground_truth)
elif data_source == "math":
return score_math(solution_str, ground_truth)
else:
return 0.0
def score_gsm8k(solution_str, ground_truth):
# GSM8K specific logic
pass
def score_math(solution_str, ground_truth):
# MATH specific logic
pass
Function Specification
----------------------
Required Signature
~~~~~~~~~~~~~~~~~~
.. code-block:: python
def compute_score(data_source, solution_str, ground_truth, extra_info):
"""
Args:
data_source (str): Dataset source
solution_str (str): Model generated response
ground_truth (str): Correct answer
extra_info (dict): Additional information
Returns:
float: Score value
"""
pass
Important Notes
~~~~~~~~~~~~~~~
1. **Function Name:** Can be customized, specify via ``custom_reward_function.name``
2. **Return Type:** Must return ``float``, typically in [0, 1] range
3. **Error Handling:** It is recommended to catch exceptions and return a default value (e.g., ``0.0``); see the sketch after this list
4. **Parameter Order:** Must follow the signature order
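Following note 3, a defensive pattern is to wrap the scoring logic so that a malformed sample never crashes training. A minimal sketch; the extraction step is a placeholder, not a recommended parser:

.. code-block:: python

    def compute_score(data_source, solution_str, ground_truth, extra_info):
        """Defensive reward function: any failure yields the default score."""
        try:
            # Placeholder extraction: take the last whitespace-separated token
            answer = solution_str.strip().split()[-1]
            return 1.0 if answer == str(ground_truth) else 0.0
        except Exception:
            return 0.0  # never crash training on a malformed sample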
================================================
FILE: examples/cpgd_trainer/run_qwen2_5-7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=cpgd
export MODEL_NAME=qwen2.5-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S") # must be set before the exports below use it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.policy_loss.loss_mode=cpgd
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=False
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/cpgd_trainer/run_qwen2_5_vl-72b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=mm_eureka
export ALG=cpgd
export MODEL_NAME=qwen2.5-vl-72b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-72B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=128
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=8
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S") # must be set before the exports below use it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.policy_loss.loss_mode=cpgd
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.del_local_ckpt_after_load=False
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_IFNAME=bond0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/cpgd_trainer/run_qwen2_5_vl-7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=mm_eureka
export ALG=cpgd
export MODEL_NAME=qwen2.5-vl-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
# Define the run timestamp here; main() runs too late for the two exports below.
timestamp=$(date +"%Y%m%d_%H%M%S")
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
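# Sanity check of the arithmetic above (hypothetical NNODES=2 job):
#   TRAIN_BATCH_SIZE    = 512 * 2 = 1024  (global prompts per training step)
#   PPO_MINI_BATCH_SIZE = 256 * 2 = 512   (global PPO mini-batch)
# Per-node sizes stay fixed, so the script scales with node count unchanged.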
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.policy_loss.loss_mode=cpgd
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=False
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
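# Note on the escaped variables above: writing \$MODEL_PATH (rather than
# $MODEL_PATH) keeps the '$' literal when the array is defined, so values are
# resolved only when main() later runs `eval "${TRAINING_CMD[@]}"`. A minimal
# sketch of the same deferred-expansion pattern:
#   CMD=(echo \$GREETING)
#   GREETING="hello"
#   eval "${CMD[@]}"   # prints: hello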
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/cpgd_trainer/run_qwen3-1.7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=cpgd
export MODEL_NAME=qwen3-1.7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-1.7B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=1
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
# Define the run timestamp here; main() runs too late for the two exports below.
timestamp=$(date +"%Y%m%d_%H%M%S")
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.policy_loss.loss_mode=cpgd
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/cpgd_trainer/run_qwen3-8b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=cpgd
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
# Define the run timestamp here; main() runs too late for the two exports below.
timestamp=$(date +"%Y%m%d_%H%M%S")
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.policy_loss.loss_mode=cpgd
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/custom_pipeline_example/custom_grpo.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Custom Pipeline Examples
This file demonstrates how users can define custom training pipelines
using the new Pipeline API. All functions are explicitly visible in the code.
"""
import numpy as np
from siirl.execution.dag.pipeline import Pipeline, NodeConfig
from siirl.execution.dag.task_graph import TaskGraph
from tensordict import TensorDict
from siirl.dag_worker.data_structures import NodeOutput
# ============================================================================
# Example 1: Use Built-in Pipeline (Simplest)
# ============================================================================
def example_builtin_grpo() -> TaskGraph:
"""
Simplest way: Use built-in GRPO pipeline directly.
This is recommended for users who want to use standard algorithms
without customization.
"""
from siirl.execution.dag.builtin_pipelines import grpo_pipeline
return grpo_pipeline()
# ============================================================================
# Example 2: GRPO with Custom Reward Function
# ============================================================================
def grpo_with_custom_reward() -> TaskGraph:
"""
Customize the reward computation while keeping other parts standard.
This example shows how to replace the reward node with a custom function
while keeping the rest of the pipeline standard.
"""
pipeline = Pipeline(
"grpo_custom_reward",
"GRPO pipeline with custom reward function"
)
# Standard rollout
pipeline.add_node(
"rollout_actor",
func="siirl.dag_worker.dagworker:DAGWorker.generate",
deps=[]
)
# Custom reward function (user's own implementation)
pipeline.add_node(
"custom_reward",
func="examples.custom_pipeline_example.custom_grpo:my_custom_reward_fn",
deps=["rollout_actor"]
)
# Standard advantage calculation and training
pipeline.add_node(
"calculate_advantages",
func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
deps=["custom_reward"]
).add_node(
"actor_old_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
deps=["calculate_advantages"],
only_forward_compute=True
).add_node(
"reference_log_prob",
func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
deps=["actor_old_log_prob"]
).add_node(
"actor_train",
func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
deps=["reference_log_prob"]
)
return pipeline.build()
def my_custom_reward_fn(batch: TensorDict, **kwargs) -> NodeOutput:
"""
User's custom reward function.
This function can implement any custom reward logic.
Here we show a simple example, but users can implement
arbitrarily complex reward computations.
Args:
batch: TensorDict containing prompts and responses
**kwargs: Additional arguments (config, etc.)
Returns:
NodeOutput: Batch with computed rewards
"""
# Option 1: Use built-in reward computation as base
from siirl.execution.scheduler.reward import compute_reward
reward_output = compute_reward(batch, kwargs.get("config"))
# Option 2: Fully custom reward logic
# responses = batch.non_tensor_batch.get("responses", [])
# custom_rewards = np.array([score_response(r) for r in responses])
# batch.batch["rewards"] = custom_rewards
# reward_output = NodeOutput(batch=batch, metrics={"avg_reward": custom_rewards.mean()})
return reward_output
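# A hypothetical scorer for the fully-custom path in Option 2 above.
# `score_response` is not part of siiRL; it only illustrates the expected
# shape of such a helper: one float per response string.
def score_response(response: str) -> float:
    """Toy scorer: reward responses carrying a '####'-style final answer."""
    if not response:
        return 0.0
    return 1.0 if "####" in response else 0.1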
================================================
FILE: examples/custom_reward/rewardfunc_gsm8k.py
================================================
import re
def extract_solution(solution_str, method="strict"):
assert method in ["strict", "flexible"]
    if method == "strict":
        # this also tests the formatting of the model
        solution = re.search(r"#### (-?[0-9.,]+)", solution_str)
        if solution is None:
            final_answer = None
        else:
            # keep only the captured number and strip thousands separators
            final_answer = solution.group(1).replace(",", "")
    elif method == "flexible":
        answer = re.findall(r"(-?[0-9.,]+)", solution_str)
final_answer = None
if len(answer) == 0:
# no reward if there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def compute_score(data_sources, solution_strs, ground_truths, extra_infos, method="strict", format_score=0.0, score=1.0, **kwargs):
"""The scoring function for GSM8k.
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
Args:
data_sources: a list of data sources
solution_strs: a list of solution texts
ground_truths: a list of ground truths
extra_infos: a list of extra infos
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
scores = []
for solution_str, ground_truth in zip(solution_strs, ground_truths):
answer = extract_solution(solution_str=solution_str, method=method)
if answer is None:
scores.append(0)
else:
if answer == ground_truth:
scores.append(score)
else:
scores.append(format_score)
return scores
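if __name__ == "__main__":
    # Minimal smoke test with hypothetical inputs (not part of the training
    # flow): the strict extractor picks up the '#### <number>' suffix, so the
    # first solution scores 1.0 and the answer-less second one scores 0.
    sols = ["8 boxes of 9 apples is 8 * 9 = 72 apples.\n#### 72", "Not sure."]
    print(compute_score(["gsm8k"] * 2, sols, ["72", "10"], [None] * 2))  # [1.0, 0]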
================================================
FILE: examples/custom_reward/run_qwen2_5-7b-custom_reward.sh
================================================
#!/usr/bin/env bash
# Exit immediately if a command exits with a non-zero status.
set -e
set -o pipefail
# Print commands and their arguments as they are executed for easy debugging.
set -x
# --- Environment Setup ---
# bash /root/install_siirl.sh
# Generate a timestamp for unique directory/file names.
timestamp=$(date +"%Y%m%d_%H%M%S")
# Force stop any existing Ray cluster to ensure a clean start.
ray stop --force
# --- Path and Environment Variable Definitions ---
# Define environment variables for data, model, and checkpoint storage paths.
export DATASET=gsm8k
export ALG=grpo
export MODEL_NAME=qwen2.5-7b
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
export CKPT_PATH=ckpts/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_$PET_NNODES
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
# Environment variables for Gloo (used for distributed communication).
#export GLOO_SOCKET_IFNAME=bond0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
# Define paths for TensorBoard and logging outputs.
export TENSORBOARD_DIR=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${PET_NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${PET_NNODES}_$timestamp
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=1024
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=16
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Cluster Configuration (Usually no changes needed below) ---
# These variables are typically set by the cluster job scheduler (e.g., Slurm, DLC).
export N_GPUS_PER_NODE=8
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export VLLM_USE_V1=1
# Calculate the global batch sizes based on the number of nodes.
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# Ray cluster connection settings.
export RAY_MASTER_PORT=6379
export RAY_DASHBOARD_PORT=8265
export RAY_MASTER_ADDR=$MASTER_ADDR
# --- Ray Cluster Start Function (Robust for Large Scale) ---
start_ray_cluster() {
# Set a generous timeout for workers waiting for the head node.
local RAY_HEAD_WAIT_TIMEOUT=600 # 10 minutes
# For stability in large clusters, explicitly set Ray to use the same network interface.
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=$INTERFACE_NAME
export RAY_GCS_SERVER_CONFIG_NIC_NAME=$INTERFACE_NAME
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
# Increase Ray GCS client connection timeout for stability.
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
# Multi-node environment
if [ "$NNODES" -gt 1 ]; then
# Head node logic (rank 0)
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
# The head's address is its own resolved IP
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head \
--port="$RAY_MASTER_PORT" \
--dashboard-port="$RAY_DASHBOARD_PORT" \
"${ray_start_common_opts[@]}" \
--system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
echo "INFO: Ray head started. Waiting for services to become healthy at $RAY_ADDRESS..."
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
local current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ "$elapsed_time" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then
echo "ERROR: Timed out after ${RAY_HEAD_WAIT_TIMEOUT}s waiting for the local head node services. Exiting." >&2
ray stop --force
exit 1
fi
echo "Head node services not healthy yet. Retrying in 5 seconds..."
sleep 5
done
echo "INFO: Head node services are healthy."
# Worker node logic (all other ranks)
else
# The address to connect to is the master node's address from the job scheduler
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head node at $head_node_address..."
local start_time=$(date +%s)
# ROBUST CHECK: Use `ray health-check` to wait for the head.
while ! ray health-check --address "$head_node_address" &>/dev/null; do
local current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ "$elapsed_time" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then
echo "ERROR: Timed out after ${RAY_HEAD_WAIT_TIMEOUT}s waiting for the head node to be healthy. Exiting." >&2
exit 1
fi
echo "Head node at $head_node_address not healthy yet. Retrying in 5 seconds..."
sleep 5
done
echo "INFO: Head node is healthy! Worker node $(hostname) is starting and connecting."
ray start --address="$head_node_address" \
"${ray_start_common_opts[@]}" \
--block # Use --block to keep the script running until the worker is stopped.
fi
# Single-node setup
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Training Launch Function ---
start_training() {
if [ "$NODE_RANK" = "0" ]; then
python3 -m siirl.main_dag \
algorithm.adv_estimator=grpo \
data.train_files=$TRAIN_DATA_PATH \
data.val_files=$TEST_DATA_PATH \
data.train_batch_size=$TRAIN_BATCH_SIZE \
data.max_prompt_length=$MAX_PROMPT_LENGTH \
data.max_response_length=$MAX_RESPONSE_LENGTH \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.policy_drift_coeff=0.001 \
actor_rollout_ref.actor.use_cpgd_loss=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=False \
actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.grad_clip=0.5 \
actor_rollout_ref.actor.clip_ratio=0.2 \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=$ROLLOUT_GPU_MEMORY_UTILIZATION \
actor_rollout_ref.rollout.max_model_len=8192 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=$ROLLOUT_N \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','tensorboard'] \
trainer.project_name=$PROJECT_NAME \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=$N_GPUS_PER_NODE \
trainer.nnodes=$NNODES \
trainer.save_freq=$SAVE_FREQ \
trainer.test_freq=$TEST_FREQ \
trainer.total_epochs=$TOTAL_EPOCHS \
trainer.resume_mode=auto \
trainer.max_actor_ckpt_to_keep=$MAX_CKPT_KEEP \
trainer.default_local_dir=$CKPT_PATH \
trainer.val_before_train=True \
custom_reward_function.path=$HOME/rl/rewardfunc_gsm8k.py \
custom_reward_function.name=compute_score \
reward_model.reward_manager=batch "$@"
fi
}
# --- Main Execution Logic ---
# Start the Ray cluster (handles both single and multi-node cases).
start_ray_cluster
# This logic should only run on the head node (NODE_RANK=0) in a multi-node setup.
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Head node is up. Waiting for all $NNODES nodes to join the cluster..."
TIMEOUT_SECONDS=600
# This command gets the list of nodes in JSON format and parses it with Python to count them.
# 'ray list nodes' is the correct and modern way to get this information from the CLI.
get_ready_nodes_cmd='ray list nodes --limit=5000 --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))"'
start_time=$(date +%s)
# Loop until the number of ready nodes equals the expected number of nodes.
while true; do
# --- Timeout Check ---
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ "$elapsed_time" -ge "$TIMEOUT_SECONDS" ]; then
echo "Error: Timeout! Waited for ${TIMEOUT_SECONDS} seconds, but not all nodes joined." >&2
exit 1 # Exit with an error code
fi
# Execute the command to get the current count of ready nodes.
# '2>/dev/null' suppresses errors if the ray client isn't ready yet, preventing script failure.
ready_nodes=$(eval "$get_ready_nodes_cmd" 2>/dev/null) || ready_nodes=0
if [ "$ready_nodes" -ge "$NNODES" ]; then
break # All nodes have joined, exit the loop.
fi
echo "Waiting for all worker nodes to register... ($ready_nodes / $NNODES nodes ready)"
sleep 2
done
echo "All $NNODES nodes have successfully joined the cluster."
fi
# --- Script Continuation ---
echo "Node initialization complete. Continuing with main task..."
start_training
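# Only the three custom_reward_function/reward_manager overrides at the end of
# start_training wire in the custom reward. To swap in a different function,
# point them at your own file (the path below is hypothetical):
#   custom_reward_function.path=$HOME/rl/my_reward.py \
#   custom_reward_function.name=compute_score \
#   reward_model.reward_manager=batch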
================================================
FILE: examples/dapo_trainer/run_qwen2_5-7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# === DAPO ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=dapo-math-17k
export ALG=grpo # DAPO uses GRPO (Group Relative Policy Optimization) as the base algorithm
export MODEL_NAME=qwen2.5-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/dapo-math-17k.parquet
export TEST_DATA_PATH=$HOME/data/datasets/gsm8k/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export INFER_MICRO_BATCH_SIZE=8
export TRAIN_MICRO_BATCH_SIZE=8
export OFFLOAD=False
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.7
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=dapo_${DATASET}_${ALG}
export EXPERIMENT_NAME=dapo_${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_experiment
# Define the run timestamp here; main() runs too late for the two exports below.
timestamp=$(date +"%Y%m%d_%H%M%S")
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
export MAX_NUM_TOKEN_PER_GPU=$(($MAX_PROMPT_LENGTH + $MAX_RESPONSE_LENGTH))
# --- DAPO-specific Hyperparameters ---
# Filter groups: Enable dynamic sampling based on trajectory variance
export ENABLE_FILTER_GROUPS=True
export FILTER_GROUPS_METRIC=acc # Metric used for filtering (accuracy)
export MAX_NUM_GEN_BATCHES=10 # Maximum generation batches before giving up
export GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3)) # Generation batch size (3x training batch per node)
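# Worked example of the oversampling budget on one node: each attempt
# generates GEN_BATCH_SIZE = 512 * 3 = 1536 prompts, keeps only prompt groups
# whose 'acc' metric varies across the ROLLOUT_N=8 rollouts (all-correct or
# all-wrong groups carry no advantage signal), and retries up to
# MAX_NUM_GEN_BATCHES=10 times to fill the 512-prompt training batch.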
# KL divergence control
export USE_KL_IN_REWARD=False # Whether to use KL penalty in reward
export KL_COEF=0.0 # KL coefficient for reward penalty
export USE_KL_LOSS=False # Whether to use KL loss in actor training
export KL_LOSS_COEF=0.0 # KL loss coefficient
# PPO clipping parameters for DAPO
export CLIP_RATIO_LOW=0.2 # Lower bound for PPO clipping
export CLIP_RATIO_HIGH=0.28 # Upper bound for PPO clipping
export LOSS_AGG_MODE="token-mean" # Loss aggregation mode
# Overlong sequence handling
export ENABLE_OVERLONG_BUFFER=True
export OVERLONG_BUFFER_LEN=512
export OVERLONG_PENALTY_FACTOR=1.0
# Sampling parameters
export TEMPERATURE=1.0 # Sampling temperature
export TOP_P=1.0 # Top-p sampling
export TOP_K=-1 # Top-k sampling (-1 for disabled)
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.gen_batch_size=\$GEN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='left'
data.shuffle=False
data.prompt_key=prompt
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.optim.lr_warmup_steps=10
actor_rollout_ref.actor.optim.weight_decay=0.1
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$TRAIN_MICRO_BATCH_SIZE
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.grad_clip=1.0
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=10.0
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=\$OFFLOAD
actor_rollout_ref.actor.fsdp_config.optimizer_offload=\$OFFLOAD
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$INFER_MICRO_BATCH_SIZE
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
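# Unlike the \$-escaped entries, the two $((...)) expressions above expand
# when the array is defined: 2048 + 4096 = 6144 tokens for both limits.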
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.top_p=\$TOP_P
actor_rollout_ref.rollout.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.val_kwargs.top_p=\$TOP_P
actor_rollout_ref.rollout.val_kwargs.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.do_sample=True
actor_rollout_ref.rollout.val_kwargs.n=1
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$INFER_MICRO_BATCH_SIZE
actor_rollout_ref.ref.fsdp_config.param_offload=\$OFFLOAD
algorithm.workflow_type=dapo
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
algorithm.filter_groups.enable=\$ENABLE_FILTER_GROUPS
algorithm.filter_groups.metric=\$FILTER_GROUPS_METRIC
algorithm.filter_groups.max_num_gen_batches=\$MAX_NUM_GEN_BATCHES
reward_model.reward_manager=dapo
reward_model.overlong_buffer.enable=\$ENABLE_OVERLONG_BUFFER
reward_model.overlong_buffer.len=\$OVERLONG_BUFFER_LEN
reward_model.overlong_buffer.penalty_factor=\$OVERLONG_PENALTY_FACTOR
trainer.critic_warmup=0
trainer.logger=["console","tensorboard"]
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/dapo_trainer/run_qwen3-235b-megatron-gspo.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For Config Debugging ---
export HYDRA_FULL_ERROR=0 # set to 1 for full Hydra stack traces
export SIIRL_LOG_VERBOSITY=INFO
export RAY_DEDUP_LOGS=1
# --- Experiment and Model Definition ---
export DATASET=DAPO-Math-17k
export MODEL_NAME=qwen3-235b-a22b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/dapo-math-17k.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
export CUDA_DEVICE_MAX_CONNECTIONS=1
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=32 # Conservative for 235B
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=$((1024 * 2))
export MAX_RESPONSE_LENGTH=$((1024 * 8))
export MAX_MODEL_LENGTH=$((1024 * 10))
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4 # Conservative for 235B
export ROLLOUT_TP=16 # High TP for 235B
export ROLLOUT_N=16
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=15
export MAX_CKPT_KEEP=5
# --- GSPO Specific Parameters ---
export LOSS_MODE=gspo
export ADV_ESTIMATOR=grpo
export CLIP_RATIO_LOW=3e-4
export CLIP_RATIO_HIGH=4e-4
export CLIP_RATIO_C=10.0
export LOSS_AGG_MODE="token-mean"
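# Why the clip range is orders of magnitude below PPO's usual 0.2: GSPO clips
# a sequence-level importance ratio, roughly
#   s_i(theta) = ( pi_theta(y_i | x) / pi_theta_old(y_i | x) )^(1 / |y_i|),
# a length-normalized geometric mean of per-token ratios that stays much
# closer to 1, so bounds like 3e-4 / 4e-4 take the role of token-level clipping.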
# --- DAPO-specific Hyperparameters ---
# Filter groups: Enable dynamic sampling based on trajectory variance
export ENABLE_FILTER_GROUPS=True
export FILTER_GROUPS_METRIC=acc # Metric used for filtering (accuracy)
export MAX_NUM_GEN_BATCHES=10 # Maximum generation batches before giving up
export GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3)) # Generation batch size (3x training batch per node)
# Overlong sequence handling
export ENABLE_OVERLONG_BUFFER=True
export OVERLONG_BUFFER_LEN=$((1024 * 4))
export OVERLONG_PENALTY_FACTOR=1.0
# Sampling parameters
export TEMPERATURE=1.0 # Sampling temperature
export TOP_P=1.0 # Top-p sampling
export TOP_K=-1 # Top-k sampling (-1 for disabled)
# --- KL Configuration ---
export USE_KL_IN_REWARD=False
export KL_COEF=0.0
export USE_KL_LOSS=False
export KL_LOSS_COEF=0.0
export KL_LOSS_TYPE=low_var_kl
# --- Megatron Parallelism for 235B ---
export ACTOR_REF_PP=8 # High pipeline parallel for 235B
export ACTOR_REF_TP=1 # Low tensor parallel
export ACTOR_REF_EP=8 # High expert parallel for MoE
export ACTOR_REF_CP=1 # Context parallel
export ACTOR_REF_SP=True # Sequence parallel
export use_dynamic_bsz=False
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed
# export GLOO_SOCKET_IFNAME=bond0
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${MODEL_NAME}
export EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.workflow_type=dapo
algorithm.adv_estimator=\$ADV_ESTIMATOR
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.gen_batch_size=\$GEN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=True
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.use_dynamic_bsz=\$use_dynamic_bsz
# GSPO specific loss configuration
actor_rollout_ref.actor.policy_loss.loss_mode=\$LOSS_MODE
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=\$CLIP_RATIO_C
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
# Megatron configuration for actor (235B)
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
# PPO configuration
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.kl_loss_type=\$KL_LOSS_TYPE
# Rollout configuration (235B)
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_MODEL_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.top_p=\$TOP_P
actor_rollout_ref.rollout.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.val_kwargs.top_p=\$TOP_P
actor_rollout_ref.rollout.val_kwargs.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.do_sample=True
actor_rollout_ref.rollout.val_kwargs.n=1
# Reference model configuration (235B)
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.ref.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.ref.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.ref.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.ref.megatron.param_offload=True
actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
# Algorithm configuration
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
algorithm.filter_groups.enable=\$ENABLE_FILTER_GROUPS
algorithm.filter_groups.metric=\$FILTER_GROUPS_METRIC
algorithm.filter_groups.max_num_gen_batches=\$MAX_NUM_GEN_BATCHES
reward_model.reward_manager=dapo
reward_model.overlong_buffer.enable=\$ENABLE_OVERLONG_BUFFER
reward_model.overlong_buffer.len=\$OVERLONG_BUFFER_LEN
reward_model.overlong_buffer.penalty_factor=\$OVERLONG_PENALTY_FACTOR
# Trainer configuration
trainer.critic_warmup=0
trainer.logger='["console","tensorboard"]'
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=off
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
dag.enable_perf=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
export NCCL_TIMEOUT=7200
export GLOO_TIMEOUT_SECONDS=7200
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting GSPO training command."
echo "Command: ${TRAINING_CMD[*]}"
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi
================================================
FILE: examples/dapo_trainer/run_qwen3-8b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# === DAPO ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=DAPO-Math-17k
export ALG=grpo # DAPO uses GRPO (Group Relative Policy Optimization) as the base algorithm
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/dapo-math-17k.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export INFER_MICRO_BATCH_SIZE=8
export TRAIN_MICRO_BATCH_SIZE=8
export OFFLOAD=False
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=8192
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S") # must be set before the exports below that reference it
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=dapo_${DATASET}_${ALG}
export EXPERIMENT_NAME=dapo_${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
export MAX_NUM_TOKEN_PER_GPU=$(($MAX_PROMPT_LENGTH + $MAX_RESPONSE_LENGTH))
# --- DAPO-specific Hyperparameters ---
# Filter groups: Enable dynamic sampling based on trajectory variance
export ENABLE_FILTER_GROUPS=True
export FILTER_GROUPS_METRIC=acc # Metric used for filtering (accuracy)
export MAX_NUM_GEN_BATCHES=10 # Maximum generation batches before giving up
export GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3)) # Generation batch size (3x training batch per node)
# KL divergence control
export USE_KL_IN_REWARD=False # Whether to use KL penalty in reward
export KL_COEF=0.0 # KL coefficient for reward penalty
export USE_KL_LOSS=False # Whether to use KL loss in actor training
export KL_LOSS_COEF=0.0 # KL loss coefficient
# PPO clipping parameters for DAPO
export CLIP_RATIO_LOW=0.2 # Lower bound for PPO clipping
export CLIP_RATIO_HIGH=0.28 # Upper bound for PPO clipping
export LOSS_AGG_MODE="token-mean" # Loss aggregation mode
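# "Clip-higher" (from DAPO): the upper bound (1 + 0.28) is looser than the lower bound
# (1 - 0.2), giving low-probability tokens room to grow and preserving exploration.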
# Overlong sequence handling
export ENABLE_OVERLONG_BUFFER=True
export OVERLONG_BUFFER_LEN=512
export OVERLONG_PENALTY_FACTOR=1.0
# Sampling parameters
export TEMPERATURE=1.0 # Sampling temperature
export TOP_P=1.0 # Top-p sampling
export TOP_K=-1 # Top-k sampling (-1 for disabled)
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.gen_batch_size=\$GEN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='left'
data.shuffle=False
data.prompt_key=prompt
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.optim.lr_warmup_steps=10
actor_rollout_ref.actor.optim.weight_decay=0.1
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$TRAIN_MICRO_BATCH_SIZE
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.grad_clip=1.0
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=10.0
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=\$OFFLOAD
actor_rollout_ref.actor.fsdp_config.optimizer_offload=\$OFFLOAD
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$INFER_MICRO_BATCH_SIZE
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.top_p=\$TOP_P
actor_rollout_ref.rollout.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.val_kwargs.top_p=\$TOP_P
actor_rollout_ref.rollout.val_kwargs.top_k=\$TOP_K
actor_rollout_ref.rollout.val_kwargs.do_sample=True
actor_rollout_ref.rollout.val_kwargs.n=1
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$INFER_MICRO_BATCH_SIZE
actor_rollout_ref.ref.fsdp_config.param_offload=\$OFFLOAD
algorithm.workflow_type=dapo
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
algorithm.filter_groups.enable=\$ENABLE_FILTER_GROUPS
algorithm.filter_groups.metric=\$FILTER_GROUPS_METRIC
algorithm.filter_groups.max_num_gen_batches=\$MAX_NUM_GEN_BATCHES
reward_model.reward_manager=dapo
reward_model.overlong_buffer.enable=\$ENABLE_OVERLONG_BUFFER
reward_model.overlong_buffer.len=\$OVERLONG_BUFFER_LEN
reward_model.overlong_buffer.penalty_factor=\$OVERLONG_PENALTY_FACTOR
trainer.critic_warmup=0
trainer.logger='["console","tensorboard"]'
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/data_preprocess/deepscaler.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the DeepScaleR dataset to parquet format
"""
import argparse
import json
import os
import datasets
from siirl.utils.extras.hdfs_io import copy, makedirs
def load_json(file_path):
with open(file_path, "r", encoding="utf-8") as file:
dataset = json.load(file)
return dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/deepscaler")
parser.add_argument("--source_dir", default=None)
parser.add_argument("--hdfs_dir", default=None)
parser.add_argument("--seed", default=15)
args = parser.parse_args()
data_source = "agentica-org/DeepScaleR-Preview-Dataset"
instruction_following = "Let's think step by step and output the final within \\boxed{}."
    if args.source_dir is None:
args.source_dir = data_source
raw_dataset = datasets.load_dataset("json", data_files=args.source_dir)
full_dataset = raw_dataset["train"]
train_test_split_dataset = full_dataset.train_test_split(test_size=0.1, seed=args.seed)
train_dataset = train_test_split_dataset["train"]
test_dataset = train_test_split_dataset["test"]
def make_map_fn(split_name):
def process_fn(example, idx):
question_raw = example.pop("problem")
answer_raw = example.pop("answer")
question = question_raw + " " + instruction_following
solution = example.pop("solution")
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": question,
}
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": answer_raw},
"extra_info": {
"split": split_name,
"index": idx,
"answer": solution,
"question": question_raw,
},
}
return data
return process_fn
processed_train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
processed_test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
processed_train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
processed_test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
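# Example invocation (illustrative local path; the script expects a JSON/JSONL export of
# the dataset when --source_dir is a file):
#   python deepscaler.py --source_dir /path/to/deepscaler.json --local_dir ~/data/deepscaler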
================================================
FILE: examples/data_preprocess/geo3k.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Geometry3k dataset to parquet format
"""
import argparse
import os
import datasets
from siirl.utils.extras.hdfs_io import copy, makedirs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/geo3k")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = "hiyouga/geometry3k"
dataset = datasets.load_dataset(data_source)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = (
r"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. "
r"The reasoning process MUST BE enclosed within tags. The final answer MUST BE put in \boxed{}."
)
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
problem = example.pop("problem")
prompt = problem + " " + instruction_following
answer = example.pop("answer")
images = example.pop("images")
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": prompt,
}
],
"images": images,
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": answer},
"extra_info": {
"split": split,
"index": idx,
"answer": answer,
"question": problem,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
================================================
FILE: examples/data_preprocess/gsm8k.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""
import argparse
import os
import re
import datasets
from siirl.utils.extras.hdfs_io import copy, makedirs
def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution
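# Illustration of the extraction above: for an answer ending in "#### 72" the regex
# matches "#### 72" (group 0), splitting on "#### " leaves "72", and for "#### 1,234"
# the comma strip yields "1234".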
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/gsm8k")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
data_source = "openai/gsm8k"
dataset = datasets.load_dataset(data_source, "main")
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = 'Let\'s think step by step and output the final answer after "####".'
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question_raw = example.pop("question")
question = question_raw + " " + instruction_following
answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": question,
}
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
================================================
FILE: examples/data_preprocess/math_dataset.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the MATH-lighteval dataset to parquet format
"""
import argparse
import os
import datasets
from siirl.utils.extras.hdfs_io import copy, makedirs
from siirl.utils.reward_score.math import last_boxed_only_string, remove_boxed
def extract_solution(solution_str):
return remove_boxed(last_boxed_only_string(solution_str))
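# Illustration (behavior inferred from the helper names): for a solution ending in
# "... the answer is \boxed{42}.", last_boxed_only_string returns "\boxed{42}" and
# remove_boxed strips the wrapper, leaving "42" as the ground truth.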
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/math")
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# 'lighteval/MATH' is no longer available on huggingface.
# Use mirror repo: DigitalLearningGmbH/MATH-lighteval
data_source = "DigitalLearningGmbH/MATH-lighteval"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = datasets.load_dataset(data_source, trust_remote_code=True)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question = example.pop("problem")
question = question + " " + instruction_following
answer = example.pop("solution")
solution = extract_solution(answer)
data = {
"data_source": data_source,
"prompt": [{"role": "user", "content": question}],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {"split": split, "index": idx},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
================================================
FILE: examples/data_preprocess/mm_eureka.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the MM Eureka dataset to parquet format
"""
import argparse
import os
from datasets import load_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--jsonl_file", type=str)
parser.add_argument("--output_dir", type=str, default="~/data/mm_eureka/")
parser.add_argument("--dataset_name", type=str, default="mm_eureka")
parser.add_argument("--nproc", type=int, default=16)
parser.add_argument("--test_split", type=int, default=5, help="split percentage of test set")
args = parser.parse_args()
dataset_name = args.dataset_name
nproc = args.nproc
instruct_prompt = "You should first thinks about the reasoning process in the mind and then provides the user with the answer."
    instruction_following = (
        r"You should first thinks about the reasoning process in the mind and then provides the user with the answer. "
        r"Your answer must be in latex format and wrapped in $...$. The reasoning process and answer are enclosed within "
        r"<think> </think> and <answer> </answer> tags, respectively, i.e., <think> Since $1+1=2$, so the answer is $2$. </think> <answer> $2$ </answer>, "
        r"which means your output should start with <think> and end with </answer>."
)
test_split = args.test_split
assert test_split > 0 and test_split < 100
train_dataset = load_dataset("json", data_files=args.jsonl_file, split=f"train[:{1 - test_split}%]")
test_dataset = load_dataset("json", data_files=args.jsonl_file, split=f"train[-{test_split}%:]")
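    # With the default test_split=5, the splits above evaluate to "train[:95%]" and
    # "train[-5%:]", a disjoint 95/5 train/test partition of the JSONL file.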
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
id = example.pop("id")
conversations = example.pop("conversations")
answer = example.pop("answer")
image_urls = example.pop("image_urls")
prompts = []
for conv in conversations:
if conv["role"] == "user":
if instruct_prompt not in conv["content"]:
conv["content"] = instruction_following + " " + conv["content"]
prompts.append(conv)
# skip other roles such as "assistant", "system", etc.
images = []
for image_url in image_urls:
with open(image_url, "rb") as f:
images.append({"path": image_url, "bytes": f.read()})
data = {
"data_source": dataset_name,
"prompt": prompts,
"images": images,
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": answer},
"extra_info": {
"id": id,
"split": split,
"index": idx,
"answer": answer,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=nproc)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=nproc)
train_file = os.path.join(args.output_dir, "train.parquet")
test_file = os.path.join(args.output_dir, "test.parquet")
train_dataset.to_parquet(train_file)
print(f"Write Done. train file: {train_file}")
test_dataset.to_parquet(test_file)
print(f"Write Done. test file: {test_file}")
================================================
FILE: examples/embodied_srpo_trainer/run_openvla_oft_libero_goal.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-GOAL ===
# ===================================================================================
#
set -e
# --- Environment Setup (Critical for siiRL) ---
export SIIRL_DIR="${SIIRL_DIR:-your_siirl_path}"
export PYTHONPATH="$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH"
# --- Experiment and Model Definition ---
export DATASET=libero_goal
export ALG=srpo
export MODEL_NAME=openvla-oft-7b
export MODEL_TYPE=openvla-oft
# --- Path Definitions (USER PROVIDED) ---
export HOME_PATH=${HOME_PATH:-your_home_path}
export TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet
export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-goal
export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Embodied AI Specific Parameters ---
export ACTION_TOKEN_LEN=7 # 7 dimensions: xyz translation (3), rotation (3), gripper (1)
export ACTION_CHUNKS_LEN=8 # OpenVLA-OFT uses 8-step action chunks
export NUM_ENVS=16 # actor_rollout_ref.embodied.env.num_envs
export MAX_EPISODE_STEPS=512 # actor_rollout_ref.embodied.env.max_steps
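# Arithmetic note: each policy step decodes ACTION_TOKEN_LEN * ACTION_CHUNKS_LEN = 7 * 8 = 56
# action tokens (one 8-step chunk of 7-DoF actions), assuming one token per action dimension.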
# --- Data and Sampling Parameters ---
export VAL_BATCH_SIZE=496 # Validation batch size
export MAX_PROMPT_LENGTH=256
export MAX_RESPONSE_LENGTH=128
# --- Embodied Sampling Parameters ---
export FILTER_ACCURACY=True # Enable accuracy-based filtering
export ACCURACY_LOWER_BOUND=0.1 # Only keep prompts with success rate >= 0.1
export ACCURACY_UPPER_BOUND=0.9 # Only keep prompts with success rate <= 0.9
export FILTER_TRUNCATED=False # Filter truncated episodes (uses env.max_steps)
export OVERSAMPLE_FACTOR=1 # Oversample factor for filtering
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE=64 # data.train_batch_size
export PPO_MINI_BATCH_SIZE=4 # actor_rollout_ref.actor.ppo_mini_batch_size
# Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES
export ROLLOUT_N_SAMPLES=8 # REUSED: Number of samples per prompt
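# e.g. with PPO_MINI_BATCH_SIZE=4 and ROLLOUT_N_SAMPLES=8, each PPO mini-batch holds 4 * 8 = 32 trajectories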
export PPO_EPOCHS=1 # actor_rollout_ref.actor.ppo_epochs
# Algorithm parameters
export LEARNING_RATE=5e-6
export WEIGHT_DECAY=0.0 # actor_rollout_ref.actor.optim.weight_decay
export CLIP_RATIO_HIGH=0.28 # actor_rollout_ref.actor.clip_ratio_high
export CLIP_RATIO_LOW=0.2 # actor_rollout_ref.actor.clip_ratio_low
export ENTROPY_COEFF=0.0
export TEMPERATURE=1.6
export GAMMA=1.0
export LAM=1.0
export GRAD_CLIP=1.0
# --- Image/Video Processing ---
export IMG_SIZE=384 # actor_rollout_ref.embodied.img_size
export ENABLE_FP16=True # actor_rollout_ref.embodied.enable_fp16
export EMBEDDING_MODEL_OFFLOAD=False # actor_rollout_ref.embodied.embedding_model_offload
export CENTER_CROP=True # actor_rollout_ref.embodied.center_crop
export NUM_IMAGES_IN_INPUT=1
export NUM_STEPS_WAIT=10 # Environment stabilization steps
# --- Trainer Configuration ---
export SAVE_FREQ=5
export TEST_FREQ=5
export TOTAL_EPOCHS=1000 # trainer.total_epochs
export MAX_CKPT_KEEP=5 # trainer.max_actor_ckpt_to_keep
export VAL_BEFORE_TRAIN=True # trainer.val_before_train
# --- Multi-node distributed training ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}
# --- Environment Variables ---
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export GLOO_SOCKET_TIMEOUT=600
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_embodied_${DATASET}
export EXPERIMENT_NAME=openvla_oft_srpo_fsdp
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}
# --- Define the Training Command ---
TRAINING_CMD=(
python3 -m siirl.main_dag
--config-name=embodied_grpo_trainer
# Data configuration
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.val_batch_size=\$VAL_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.dataset_type=embodied
# Reward
reward_model.reward_manager=embodied
reward_model.reward_kwargs.action_token_len=7
reward_model.reward_kwargs.reward_coef=5.0
# Algorithm configuration
algorithm.workflow_type=embodied
algorithm.adv_estimator=grpo
algorithm.gamma=\$GAMMA
algorithm.lam=\$LAM
algorithm.norm_adv_by_std_in_grpo=True
# Embodied sampling configuration (aligned with DAPO architecture)
algorithm.filter_groups.enable=True
algorithm.embodied_sampling.filter_accuracy=\$FILTER_ACCURACY
algorithm.embodied_sampling.accuracy_lower_bound=\$ACCURACY_LOWER_BOUND
algorithm.embodied_sampling.accuracy_upper_bound=\$ACCURACY_UPPER_BOUND
algorithm.embodied_sampling.filter_truncated=\$FILTER_TRUNCATED
algorithm.embodied_sampling.oversample_factor=\$OVERSAMPLE_FACTOR
# Model configuration
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.model.model_type=embodied
actor_rollout_ref.model.trust_remote_code=True
# Actor configuration
actor_rollout_ref.actor.optim.lr=\$LEARNING_RATE
actor_rollout_ref.actor.optim.weight_decay=\$WEIGHT_DECAY
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_epochs=\$PPO_EPOCHS
actor_rollout_ref.actor.grad_clip=\$GRAD_CLIP
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.entropy_coeff=\$ENTROPY_COEFF
actor_rollout_ref.actor.shuffle=True
# Actor FSDP configuration
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.grad_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
# Rollout configuration
actor_rollout_ref.rollout.name=hf
actor_rollout_ref.rollout.n=\$ROLLOUT_N_SAMPLES
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.do_sample=True
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=512
# Embodied AI specific configuration
actor_rollout_ref.embodied.embodied_type=\$MODEL_TYPE
actor_rollout_ref.embodied.action_token_len=\$ACTION_TOKEN_LEN
actor_rollout_ref.embodied.action_chunks_len=\$ACTION_CHUNKS_LEN
actor_rollout_ref.embodied.video_embedding_model_path=\$VJEPA_MODEL_PATH
actor_rollout_ref.embodied.embedding_img_size=\$IMG_SIZE
actor_rollout_ref.embodied.embedding_enable_fp16=\$ENABLE_FP16
actor_rollout_ref.embodied.embedding_model_offload=\$EMBEDDING_MODEL_OFFLOAD
actor_rollout_ref.embodied.center_crop=\$CENTER_CROP
actor_rollout_ref.embodied.num_images_in_input=\$NUM_IMAGES_IN_INPUT
actor_rollout_ref.embodied.unnorm_key=\$DATASET
# Environment configuration
actor_rollout_ref.embodied.env.env_type=libero
actor_rollout_ref.embodied.env.env_name=\$DATASET
actor_rollout_ref.embodied.env.num_envs=\$NUM_ENVS
actor_rollout_ref.embodied.env.max_steps=\$MAX_EPISODE_STEPS
actor_rollout_ref.embodied.env.num_steps_wait=\$NUM_STEPS_WAIT
actor_rollout_ref.embodied.env.num_trials_per_task=50
actor_rollout_ref.embodied.env.model_family=openvla
# Critic configuration (SRPO doesn't use critic)
critic.use_critic_model=False
# Trainer configuration
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.nnodes=\$NNODES
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.default_local_dir=\$CKPT_PATH
trainer.resume_mode=auto
trainer.val_before_train=\$VAL_BEFORE_TRAIN
)
# ===================================================================================
# === EXECUTION LOGIC ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/embodied_srpo_trainer/run_openvla_oft_libero_long.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-10 ===
# ===================================================================================
#
set -e
# --- Environment Setup (Critical for siiRL) ---
export SIIRL_DIR="${SIIRL_DIR:-your_siirl_path}"
export PYTHONPATH="$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH"
# --- Experiment and Model Definition ---
export DATASET=libero_10
export ALG=srpo
export MODEL_NAME=openvla-oft-7b
export MODEL_TYPE=openvla-oft
# --- Path Definitions (USER PROVIDED) ---
export HOME_PATH=${HOME_PATH:-your_home_path}
export TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet
export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long
export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Embodied AI Specific Parameters ---
export ACTION_TOKEN_LEN=7 # 7 dimensions: xyz translation (3), rotation (3), gripper (1)
export ACTION_CHUNKS_LEN=8 # OpenVLA-OFT uses 8-step action chunks
export NUM_ENVS=16 # actor_rollout_ref.embodied.env.num_envs
export MAX_EPISODE_STEPS=512 # actor_rollout_ref.embodied.env.max_steps
# --- Data and Sampling Parameters ---
export VAL_BATCH_SIZE=496 # Validation batch size
export MAX_PROMPT_LENGTH=256
export MAX_RESPONSE_LENGTH=128
# --- Embodied Sampling Parameters ---
export FILTER_ACCURACY=True # Enable accuracy-based filtering
export ACCURACY_LOWER_BOUND=0.1 # Only keep prompts with success rate >= 0.1
export ACCURACY_UPPER_BOUND=0.9 # Only keep prompts with success rate <= 0.9
export FILTER_TRUNCATED=False # Filter truncated episodes (uses env.max_steps)
export OVERSAMPLE_FACTOR=1 # Oversample factor for filtering
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE=64 # data.train_batch_size
export PPO_MINI_BATCH_SIZE=4 # actor_rollout_ref.actor.ppo_mini_batch_size
# Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES
export ROLLOUT_N_SAMPLES=8 # REUSED: Number of samples per prompt
export PPO_EPOCHS=1 # actor_rollout_ref.actor.ppo_epochs
# Algorithm parameters
export LEARNING_RATE=5e-6
export WEIGHT_DECAY=0.0 # actor_rollout_ref.actor.optim.weight_decay
export CLIP_RATIO_HIGH=0.28 # actor_rollout_ref.actor.clip_ratio_high
export CLIP_RATIO_LOW=0.2 # actor_rollout_ref.actor.clip_ratio_low
export ENTROPY_COEFF=0.0
export TEMPERATURE=1.6
export GAMMA=1.0
export LAM=1.0
export GRAD_CLIP=1.0
# --- Image/Video Processing ---
export IMG_SIZE=384 # actor_rollout_ref.embodied.img_size
export ENABLE_FP16=True # actor_rollout_ref.embodied.enable_fp16
export EMBEDDING_MODEL_OFFLOAD=False # actor_rollout_ref.embodied.embedding_model_offload
export CENTER_CROP=True # actor_rollout_ref.embodied.center_crop
export NUM_IMAGES_IN_INPUT=1
export NUM_STEPS_WAIT=10 # Environment stabilization steps
# --- Trainer Configuration ---
export SAVE_FREQ=5
export TEST_FREQ=5
export TOTAL_EPOCHS=1000 # trainer.total_epochs
export MAX_CKPT_KEEP=5 # trainer.max_actor_ckpt_to_keep
export VAL_BEFORE_TRAIN=True # trainer.val_before_train
# --- Multi-node distributed training ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}
# --- Environment Variables ---
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export GLOO_SOCKET_TIMEOUT=600
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_embodied_${DATASET}
export EXPERIMENT_NAME=openvla_oft_srpo_fsdp
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}
# --- Define the Training Command ---
TRAINING_CMD=(
python3 -m siirl.main_dag
# Data configuration
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.val_batch_size=\$VAL_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.dataset_type=embodied
# Reward
reward_model.reward_manager=embodied
reward_model.reward_kwargs.action_token_len=7
reward_model.reward_kwargs.reward_coef=5.0
# Algorithm configuration
algorithm.workflow_type=embodied
algorithm.adv_estimator=grpo
algorithm.gamma=\$GAMMA
algorithm.lam=\$LAM
algorithm.norm_adv_by_std_in_grpo=True
# Embodied sampling configuration (aligned with DAPO architecture)
algorithm.filter_groups.enable=True
algorithm.embodied_sampling.filter_accuracy=\$FILTER_ACCURACY
algorithm.embodied_sampling.accuracy_lower_bound=\$ACCURACY_LOWER_BOUND
algorithm.embodied_sampling.accuracy_upper_bound=\$ACCURACY_UPPER_BOUND
algorithm.embodied_sampling.filter_truncated=\$FILTER_TRUNCATED
algorithm.embodied_sampling.oversample_factor=\$OVERSAMPLE_FACTOR
# Model configuration
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.model.model_type=embodied
actor_rollout_ref.model.trust_remote_code=True
# Actor configuration
actor_rollout_ref.actor.optim.lr=\$LEARNING_RATE
actor_rollout_ref.actor.optim.weight_decay=\$WEIGHT_DECAY
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_epochs=\$PPO_EPOCHS
actor_rollout_ref.actor.grad_clip=\$GRAD_CLIP
actor_rollout_ref.actor.clip_ratio_c=10000.0
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.entropy_coeff=\$ENTROPY_COEFF
actor_rollout_ref.actor.shuffle=True
# Actor FSDP configuration
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.grad_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
# Rollout configuration
actor_rollout_ref.rollout.name=hf
actor_rollout_ref.rollout.n=\$ROLLOUT_N_SAMPLES
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.do_sample=True
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=512
# Embodied AI specific configuration
actor_rollout_ref.embodied.embodied_type=\$MODEL_TYPE
actor_rollout_ref.embodied.action_token_len=\$ACTION_TOKEN_LEN
actor_rollout_ref.embodied.action_chunks_len=\$ACTION_CHUNKS_LEN
actor_rollout_ref.embodied.video_embedding_model_path=\$VJEPA_MODEL_PATH
actor_rollout_ref.embodied.embedding_img_size=\$IMG_SIZE
actor_rollout_ref.embodied.embedding_enable_fp16=\$ENABLE_FP16
actor_rollout_ref.embodied.embedding_model_offload=\$EMBEDDING_MODEL_OFFLOAD
actor_rollout_ref.embodied.center_crop=\$CENTER_CROP
actor_rollout_ref.embodied.num_images_in_input=\$NUM_IMAGES_IN_INPUT
actor_rollout_ref.embodied.unnorm_key=\$DATASET
# Environment configuration
actor_rollout_ref.embodied.env.env_type=libero
actor_rollout_ref.embodied.env.env_name=\$DATASET
actor_rollout_ref.embodied.env.num_envs=\$NUM_ENVS
actor_rollout_ref.embodied.env.max_steps=\$MAX_EPISODE_STEPS
actor_rollout_ref.embodied.env.num_steps_wait=\$NUM_STEPS_WAIT
actor_rollout_ref.embodied.env.num_trials_per_task=50
actor_rollout_ref.embodied.env.model_family=openvla
# Critic configuration (SRPO doesn't use critic)
critic.use_critic_model=False
# Trainer configuration
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.nnodes=\$NNODES
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.default_local_dir=\$CKPT_PATH
trainer.resume_mode=auto
trainer.val_before_train=\$VAL_BEFORE_TRAIN
)
# ===================================================================================
# === EXECUTION LOGIC ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/embodied_srpo_trainer/run_openvla_oft_libero_object.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-OBJECT ===
# ===================================================================================
#
set -e
# --- Environment Setup (Critical for siiRL) ---
export SIIRL_DIR="${SIIRL_DIR:-your_siirl_path}"
export PYTHONPATH="$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH"
# --- Experiment and Model Definition ---
export DATASET=libero_object
export ALG=srpo
export MODEL_NAME=openvla-oft-7b
export MODEL_TYPE=openvla-oft
# --- Path Definitions (USER PROVIDED) ---
export HOME_PATH=${HOME_PATH:-your_home_path}
export TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet
export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-object
export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Embodied AI Specific Parameters ---
export ACTION_TOKEN_LEN=7      # 7 action dimensions: position xyz (3), rotation (3), gripper (1)
export ACTION_CHUNKS_LEN=8 # OpenVLA-OFT uses 8-step action chunks
export NUM_ENVS=16 # actor_rollout_ref.embodied.env.num_envs
export MAX_EPISODE_STEPS=512 # actor_rollout_ref.embodied.env.max_steps
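# With these settings a single policy call decodes ACTION_CHUNKS_LEN x ACTION_TOKEN_LEN
# = 8 x 7 = 56 action tokens (assuming one chunk sequence per call).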
# --- Data and Sampling Parameters ---
export VAL_BATCH_SIZE=496 # Validation batch size
export MAX_PROMPT_LENGTH=256
export MAX_RESPONSE_LENGTH=128
# --- Embodied Sampling Parameters ---
export FILTER_ACCURACY=True # Enable accuracy-based filtering
export ACCURACY_LOWER_BOUND=0.1 # Only keep prompts with success rate >= 0.1
export ACCURACY_UPPER_BOUND=0.9 # Only keep prompts with success rate <= 0.9
export FILTER_TRUNCATED=False # Filter truncated episodes (uses env.max_steps)
export OVERSAMPLE_FACTOR=1 # Oversample factor for filtering
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE=64 # data.train_batch_size
export PPO_MINI_BATCH_SIZE=4 # actor_rollout_ref.actor.ppo_mini_batch_size
# Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES
export ROLLOUT_N_SAMPLES=8 # REUSED: Number of samples per prompt
export PPO_EPOCHS=1 # actor_rollout_ref.actor.ppo_epochs
# Algorithm parameters
export LEARNING_RATE=5e-6
export WEIGHT_DECAY=0.0 # actor_rollout_ref.actor.optim.weight_decay
export CLIP_RATIO_HIGH=0.28 # actor_rollout_ref.actor.clip_ratio_high
export CLIP_RATIO_LOW=0.2 # actor_rollout_ref.actor.clip_ratio_low
export ENTROPY_COEFF=0.0
export TEMPERATURE=1.6
export GAMMA=1.0
export LAM=1.0
export GRAD_CLIP=1.0
# --- Image/Video Processing ---
export IMG_SIZE=384            # actor_rollout_ref.embodied.embedding_img_size
export ENABLE_FP16=True        # actor_rollout_ref.embodied.embedding_enable_fp16
export EMBEDDING_MODEL_OFFLOAD=False # actor_rollout_ref.embodied.embedding_model_offload
export CENTER_CROP=True # actor_rollout_ref.embodied.center_crop
export NUM_IMAGES_IN_INPUT=1
export NUM_STEPS_WAIT=10 # Environment stabilization steps
# --- Trainer Configuration ---
export SAVE_FREQ=5
export TEST_FREQ=5
export TOTAL_EPOCHS=1000 # trainer.total_epochs
export MAX_CKPT_KEEP=5 # trainer.max_actor_ckpt_to_keep
export VAL_BEFORE_TRAIN=True # trainer.val_before_train
# --- Multi-node distributed training ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}
# --- Environment Variables ---
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export GLOO_SOCKET_TIMEOUT=600
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_embodied_${DATASET}
export EXPERIMENT_NAME=openvla_oft_srpo_fsdp
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}
# --- Define the Training Command ---
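# NOTE: '$' is escaped below so variables expand when main() eval's the command
# at launch time, not when this array is defined.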
TRAINING_CMD=(
python3 -m siirl.main_dag
# Data configuration
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.val_batch_size=\$VAL_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.dataset_type=embodied
# Reward
reward_model.reward_manager=embodied
reward_model.reward_kwargs.action_token_len=7
reward_model.reward_kwargs.reward_coef=5.0
# Algorithm configuration
algorithm.workflow_type=embodied
algorithm.adv_estimator=grpo
algorithm.gamma=\$GAMMA
algorithm.lam=\$LAM
algorithm.norm_adv_by_std_in_grpo=True
# Embodied sampling configuration (aligned with DAPO architecture)
algorithm.filter_groups.enable=True
algorithm.embodied_sampling.filter_accuracy=\$FILTER_ACCURACY
algorithm.embodied_sampling.accuracy_lower_bound=\$ACCURACY_LOWER_BOUND
algorithm.embodied_sampling.accuracy_upper_bound=\$ACCURACY_UPPER_BOUND
algorithm.embodied_sampling.filter_truncated=\$FILTER_TRUNCATED
algorithm.embodied_sampling.oversample_factor=\$OVERSAMPLE_FACTOR
# Model configuration
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.model.model_type=embodied
actor_rollout_ref.model.trust_remote_code=True
# Actor configuration
actor_rollout_ref.actor.optim.lr=\$LEARNING_RATE
actor_rollout_ref.actor.optim.weight_decay=\$WEIGHT_DECAY
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_epochs=\$PPO_EPOCHS
actor_rollout_ref.actor.grad_clip=\$GRAD_CLIP
actor_rollout_ref.actor.clip_ratio_c=10000.0
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.entropy_coeff=\$ENTROPY_COEFF
actor_rollout_ref.actor.shuffle=True
# Actor FSDP configuration
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.grad_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
# Rollout configuration
actor_rollout_ref.rollout.name=hf
actor_rollout_ref.rollout.n=\$ROLLOUT_N_SAMPLES
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.do_sample=True
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=512
# Embodied AI specific configuration
actor_rollout_ref.embodied.embodied_type=\$MODEL_TYPE
actor_rollout_ref.embodied.action_token_len=\$ACTION_TOKEN_LEN
actor_rollout_ref.embodied.action_chunks_len=\$ACTION_CHUNKS_LEN
actor_rollout_ref.embodied.video_embedding_model_path=\$VJEPA_MODEL_PATH
actor_rollout_ref.embodied.embedding_img_size=\$IMG_SIZE
actor_rollout_ref.embodied.embedding_enable_fp16=\$ENABLE_FP16
actor_rollout_ref.embodied.embedding_model_offload=\$EMBEDDING_MODEL_OFFLOAD
actor_rollout_ref.embodied.center_crop=\$CENTER_CROP
actor_rollout_ref.embodied.num_images_in_input=\$NUM_IMAGES_IN_INPUT
actor_rollout_ref.embodied.unnorm_key=\$DATASET
# Environment configuration
actor_rollout_ref.embodied.env.env_type=libero
actor_rollout_ref.embodied.env.env_name=\$DATASET
actor_rollout_ref.embodied.env.num_envs=\$NUM_ENVS
actor_rollout_ref.embodied.env.max_steps=\$MAX_EPISODE_STEPS
actor_rollout_ref.embodied.env.num_steps_wait=\$NUM_STEPS_WAIT
actor_rollout_ref.embodied.env.num_trials_per_task=50
actor_rollout_ref.embodied.env.model_family=openvla
# Critic configuration (SRPO doesn't use critic)
critic.use_critic_model=False
# Trainer configuration
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.nnodes=\$NNODES
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.default_local_dir=\$CKPT_PATH
trainer.resume_mode=auto
trainer.val_before_train=\$VAL_BEFORE_TRAIN
)
# ===================================================================================
# === EXECUTION LOGIC ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/embodied_srpo_trainer/run_openvla_oft_libero_spatial.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-SPATIAL ===
# ===================================================================================
#
set -e
# --- Environment Setup (Critical for siiRL) ---
export SIIRL_DIR="${SIIRL_DIR:-your_siirl_path}"
export PYTHONPATH="$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH"
# --- Experiment and Model Definition ---
export DATASET=libero_spatial
export ALG=srpo
export MODEL_NAME=openvla-oft-7b
export MODEL_TYPE=openvla-oft
# --- Path Definitions (USER PROVIDED) ---
export HOME_PATH=${HOME_PATH:-your_home_path}
export TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet
export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-spatial
export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Embodied AI Specific Parameters ---
export ACTION_TOKEN_LEN=7      # 7 action dimensions: position xyz (3), rotation (3), gripper (1)
export ACTION_CHUNKS_LEN=8 # OpenVLA-OFT uses 8-step action chunks
export NUM_ENVS=16 # actor_rollout_ref.embodied.env.num_envs
export MAX_EPISODE_STEPS=512 # actor_rollout_ref.embodied.env.max_steps
# --- Data and Sampling Parameters ---
export VAL_BATCH_SIZE=496 # Validation batch size
export MAX_PROMPT_LENGTH=256
export MAX_RESPONSE_LENGTH=128
# --- Embodied Sampling Parameters ---
export FILTER_ACCURACY=True # Enable accuracy-based filtering
export ACCURACY_LOWER_BOUND=0.1 # Only keep prompts with success rate >= 0.1
export ACCURACY_UPPER_BOUND=0.9 # Only keep prompts with success rate <= 0.9
export FILTER_TRUNCATED=False # Filter truncated episodes (uses env.max_steps)
export OVERSAMPLE_FACTOR=1 # Oversample factor for filtering
# --- Training Hyperparameters ---
export TRAIN_BATCH_SIZE=64 # data.train_batch_size
export PPO_MINI_BATCH_SIZE=4 # actor_rollout_ref.actor.ppo_mini_batch_size
# Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES
export ROLLOUT_N_SAMPLES=8 # REUSED: Number of samples per prompt
export PPO_EPOCHS=1 # actor_rollout_ref.actor.ppo_epochs
# Algorithm parameters
export LEARNING_RATE=5e-6
export WEIGHT_DECAY=0.0 # actor_rollout_ref.actor.optim.weight_decay
export CLIP_RATIO_HIGH=0.28 # actor_rollout_ref.actor.clip_ratio_high
export CLIP_RATIO_LOW=0.2 # actor_rollout_ref.actor.clip_ratio_low
export ENTROPY_COEFF=0.0
export TEMPERATURE=1.6
export GAMMA=1.0
export LAM=1.0
export GRAD_CLIP=1.0
# --- Image/Video Processing ---
export IMG_SIZE=384            # actor_rollout_ref.embodied.embedding_img_size
export ENABLE_FP16=True        # actor_rollout_ref.embodied.embedding_enable_fp16
export EMBEDDING_MODEL_OFFLOAD=False # actor_rollout_ref.embodied.embedding_model_offload
export CENTER_CROP=True        # actor_rollout_ref.embodied.center_crop
export NUM_IMAGES_IN_INPUT=1
export NUM_STEPS_WAIT=10 # Environment stabilization steps
# --- Trainer Configuration ---
export SAVE_FREQ=5
export TEST_FREQ=5
export TOTAL_EPOCHS=1000 # trainer.total_epochs
export MAX_CKPT_KEEP=5 # trainer.max_actor_ckpt_to_keep
export VAL_BEFORE_TRAIN=True # trainer.val_before_train
# --- Multi-node distributed training ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29500}
# --- Environment Variables ---
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export GLOO_SOCKET_TIMEOUT=600
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_embodied_${DATASET}
export EXPERIMENT_NAME=openvla_oft_srpo_fsdp
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}
# --- Define the Training Command ---
TRAINING_CMD=(
python3 -m siirl.main_dag
# Data configuration
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.val_batch_size=\$VAL_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.dataset_type=embodied
# Reward
reward_model.reward_manager=embodied
reward_model.reward_kwargs.action_token_len=7
reward_model.reward_kwargs.reward_coef=5.0
# Algorithm configuration
algorithm.workflow_type=embodied
algorithm.adv_estimator=grpo
algorithm.gamma=\$GAMMA
algorithm.lam=\$LAM
algorithm.norm_adv_by_std_in_grpo=True
# Embodied sampling configuration (aligned with DAPO architecture)
algorithm.filter_groups.enable=True
algorithm.embodied_sampling.filter_accuracy=\$FILTER_ACCURACY
algorithm.embodied_sampling.accuracy_lower_bound=\$ACCURACY_LOWER_BOUND
algorithm.embodied_sampling.accuracy_upper_bound=\$ACCURACY_UPPER_BOUND
algorithm.embodied_sampling.filter_truncated=\$FILTER_TRUNCATED
algorithm.embodied_sampling.oversample_factor=\$OVERSAMPLE_FACTOR
# Model configuration
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.model.model_type=embodied
actor_rollout_ref.model.trust_remote_code=True
# Actor configuration
actor_rollout_ref.actor.optim.lr=\$LEARNING_RATE
actor_rollout_ref.actor.optim.weight_decay=\$WEIGHT_DECAY
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_epochs=\$PPO_EPOCHS
actor_rollout_ref.actor.grad_clip=\$GRAD_CLIP
actor_rollout_ref.actor.clip_ratio_c=10000.0
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.entropy_coeff=\$ENTROPY_COEFF
actor_rollout_ref.actor.shuffle=True
# Actor FSDP configuration
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.grad_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
# Rollout configuration
actor_rollout_ref.rollout.name=hf
actor_rollout_ref.rollout.n=\$ROLLOUT_N_SAMPLES
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.do_sample=True
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=512
# Embodied AI specific configuration
actor_rollout_ref.embodied.embodied_type=\$MODEL_TYPE
actor_rollout_ref.embodied.action_token_len=\$ACTION_TOKEN_LEN
actor_rollout_ref.embodied.action_chunks_len=\$ACTION_CHUNKS_LEN
actor_rollout_ref.embodied.video_embedding_model_path=\$VJEPA_MODEL_PATH
actor_rollout_ref.embodied.embedding_img_size=\$IMG_SIZE
actor_rollout_ref.embodied.embedding_enable_fp16=\$ENABLE_FP16
actor_rollout_ref.embodied.embedding_model_offload=\$EMBEDDING_MODEL_OFFLOAD
actor_rollout_ref.embodied.center_crop=\$CENTER_CROP
actor_rollout_ref.embodied.num_images_in_input=\$NUM_IMAGES_IN_INPUT
actor_rollout_ref.embodied.unnorm_key=\$DATASET
# Environment configuration
actor_rollout_ref.embodied.env.env_type=libero
actor_rollout_ref.embodied.env.env_name=\$DATASET
actor_rollout_ref.embodied.env.num_envs=\$NUM_ENVS
actor_rollout_ref.embodied.env.max_steps=\$MAX_EPISODE_STEPS
actor_rollout_ref.embodied.env.num_steps_wait=\$NUM_STEPS_WAIT
actor_rollout_ref.embodied.env.num_trials_per_task=50
actor_rollout_ref.embodied.env.model_family=openvla
# Critic configuration (SRPO doesn't use critic)
critic.use_critic_model=False
# Trainer configuration
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.nnodes=\$NNODES
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.default_local_dir=\$CKPT_PATH
trainer.resume_mode=auto
trainer.val_before_train=\$VAL_BEFORE_TRAIN
)
# ===================================================================================
# === EXECUTION LOGIC ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/experimental/marft/config/code_env.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from siirl.utils.reward_score.prime_code import compute_score
from typing import Any, Dict, Optional, Tuple
import asyncio
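# Minimal MARFT code environment: step() scores the final agent's action with
# prime_code.compute_score (run in a thread executor so the event loop is not
# blocked) and appends verbal right/wrong feedback to every agent's observation.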
class CodeEnv():
def __init__(self):
pass
def reset(self) -> Any:
pass
async def step(self, actions, ground_truth):
actor_action = actions[-1]
loop = asyncio.get_event_loop()
score, _ = await loop.run_in_executor(
None,
compute_score,
actor_action, ground_truth
)
score = float(score)
should_stop = False
if score == 1.0:
next_obs = [act + ". This answer is right." for act in actions]
should_stop = True
else:
next_obs = [act + ". This answer is wrong." for act in actions]
return next_obs, score, should_stop
================================================
FILE: examples/experimental/marft/config/math_env.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from siirl.utils.reward_score.math import compute_score
from typing import Any, Dict, Optional, Tuple
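# Math counterpart of CodeEnv: step() scores the final agent's answer with
# math.compute_score; a score of 1.0 marks success and ends the episode.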
class MathEnv():
def __init__(self):
pass
def reset(self) -> Any:
pass
async def step(self, actions, ground_truth):
actor_action = actions[-1]
loop = asyncio.get_event_loop()
score = await loop.run_in_executor(
None,
compute_score,
actor_action, ground_truth
)
should_stop = False
if score == 1.0:
next_obs = [act + " This answer is right." for act in actions]
should_stop = True
else:
next_obs = [act + " This answer is wrong." for act in actions]
return next_obs, score, should_stop
================================================
FILE: examples/experimental/marft/config/process.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from string import Template
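# pre_process: decode the tokenized prompt, wrap it in the agent-specific chat
# template supplied via pre_process_kwargs in the workflow YAML, then re-tokenize
# with the generation prompt appended.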
def pre_process(tokenizer, prompt_id, obs, **kwargs):
pre_chat_template = Template(kwargs.get("pre_chat_template", ""))
prompt = tokenizer.decode(prompt_id)
prompt = pre_chat_template.substitute(prompt = prompt)
message = [
{"role": "system", "content": ""},
{"role": "user", "content": prompt}
]
return tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, add_special_tokens=False)
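# post_process: splice the agent's role marker (post_chat_template) between the
# prompt ids and the generated response ids for the next stage's context.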
def post_process(tokenizer, prompt_id, response_id, **kwargs):
    # Guard against a missing template: tokenizer.encode(None) would raise.
    post_chat_template = kwargs.get("post_chat_template", "")
    post_chat_template_id = tokenizer.encode(post_chat_template) if post_chat_template else []
    return prompt_id + post_chat_template_id + response_id
================================================
FILE: examples/experimental/marft/config/workflow_marft.yaml
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
dag_id: "marft_ppo_training_pipeline"
description: "This is MARFT DAG workflow configured via YAML."
actor_1_config: &actor1_config
rollout.log_prob_micro_batch_size_per_gpu: 16
rollout.tensor_model_parallel_size: 4
rollout.gpu_memory_utilization: 0.3
rollout.n: 1
actor_2_config: &actor2_config
rollout.log_prob_micro_batch_size_per_gpu: 16
rollout.tensor_model_parallel_size: 4
rollout.gpu_memory_utilization: 0.3
rollout.n: 1
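# DAG order: reasoner rollout -> actor rollout (with MathEnv feedback) -> reward
# -> old log-probs for both actors -> reference log-probs -> critic values ->
# advantages -> critic updates -> actor updates. agent_group 0/1 keeps the two
# agents' models separate.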
nodes:
- node_id: "rollout_reasoner"
node_type: "MODEL_INFERENCE"
node_role: "ROLLOUT"
config: *actor1_config
agent_group: 0
dependencies: []
agent_options:
obs_with_env: true
process_path: examples/experimental/marft/config/process.py
pre_process_kwargs:
pre_chat_template: "<|im_start|>system: Two LLM agents (Reasoner -> Actor) collaborate step-by-step to solve math problems. You are the **Reasoner**: Analyze the original problem, historical actions, and reflection data (if provided) to determine the critical next step. Guide the Actor by providing concise reasoning for the optimal operation.<|im_end|>\n <|im_start|> problem: ${prompt} <|im_end|>\n <|im_start|> reasoner: "
post_process_kwargs:
post_chat_template: " <|im_start|> reasoner: "
- node_id: "rollout_actor"
node_type: "MODEL_INFERENCE"
node_role: "ROLLOUT"
config: *actor2_config
agent_group: 1
dependencies:
- "rollout_reasoner"
agent_options:
obs_with_env: true
process_path: examples/experimental/marft/config/process.py
pre_process_kwargs:
pre_chat_template: "<|im_start|>system: Two LLM agents (Reasoner -> Actor) collaborate step-by-step. You are the **Actor**: Execute operations using original problem, action history, and Reasoner's guidance. Give the final output within \\boxed{}.<|im_end|>\n ${prompt} <|im_start|> actor: "
post_process_kwargs:
post_chat_template: " <|im_start|> actor: "
env_path: [examples/experimental/marft/config/math_env.py:MathEnv]
- node_id: "function_reward"
node_type: "COMPUTE"
node_role: "REWARD"
agent_group: 1
dependencies:
- "rollout_actor"
- node_id: "actor_1_old_log_prob"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
only_forward_compute: true
agent_group: 0
config: *actor1_config
dependencies:
- "function_reward"
- node_id: "actor_2_old_log_prob"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
only_forward_compute: true
agent_group: 1
config: *actor2_config
dependencies:
- "actor_1_old_log_prob"
- node_id: "reference_1_log_prob"
node_type: "MODEL_TRAIN"
node_role: "REFERENCE"
agent_group: 0
dependencies:
- "actor_2_old_log_prob"
- node_id: "reference_2_log_prob"
node_type: "MODEL_TRAIN"
node_role: "REFERENCE"
agent_group: 1
dependencies:
- "reference_1_log_prob"
- node_id: "critic_1_value"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 0
only_forward_compute: true
dependencies:
- "reference_2_log_prob"
- node_id: "critic_2_value"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 1
only_forward_compute: true
dependencies:
- "critic_1_value"
agent_options:
share_instance: 0
- node_id: "calculate_2_advantages"
node_type: "COMPUTE"
node_role: "ADVANTAGE"
agent_group: 1
dependencies:
- "critic_2_value"
- node_id: "critic_1_train"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 0
dependencies:
- "calculate_2_advantages"
- node_id: "critic_2_train"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 1
dependencies:
- "critic_1_train"
agent_options:
share_instance: 0
- node_id: "actor_1_train"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
agent_group: 0
config: *actor1_config
agent_options:
train_cycle: 15
dependencies:
- "critic_2_train"
- node_id: "actor_2_train"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
agent_group: 1
config: *actor2_config
agent_options:
train_cycle: 15
dependencies:
- "actor_1_train"
================================================
FILE: examples/experimental/marft/config/workflow_marft_code.yaml
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
dag_id: "marft_ppo_training_pipeline"
description: "This is MARFT DAG workflow configured via YAML."
actor_1_config: &actor1_config
rollout.log_prob_micro_batch_size_per_gpu: 16
rollout.tensor_model_parallel_size: 4
rollout.gpu_memory_utilization: 0.3
rollout.n: 1
actor_2_config: &actor2_config
rollout.log_prob_micro_batch_size_per_gpu: 16
rollout.tensor_model_parallel_size: 4
rollout.gpu_memory_utilization: 0.3
rollout.n: 1
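# Same DAG as workflow_marft.yaml, but with CodeEnv feedback and coder-oriented prompts.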
nodes:
- node_id: "rollout_reasoner"
node_type: "MODEL_INFERENCE"
node_role: "ROLLOUT"
config: *actor1_config
agent_group: 0
dependencies: []
agent_options:
obs_with_env: true
process_path: examples/experimental/marft/config/process.py
pre_process_kwargs:
pre_chat_template: "Two LLM agents (Reasoner → Coder) collaborate to solve Codeforces Python coding problems.\nYou are the **Reasoner**: Analyze the problem statement, constraints, and expected behavior.\nIdentify edge cases, break the problem into logical steps, and suggest a high-level algorithmic plan.\nYou may include helpful pseudocode and edge case analysis, but do **not** write actual Python code.\n<|im_start|>problem: ${prompt}\n reasoner: "
post_process_kwargs:
post_chat_template: " reasoner: "
- node_id: "rollout_actor"
node_type: "MODEL_INFERENCE"
node_role: "ROLLOUT"
config: *actor2_config
agent_group: 1
dependencies:
- "rollout_reasoner"
agent_options:
obs_with_env: true
process_path: examples/experimental/marft/config/process.py
pre_process_kwargs:
pre_chat_template: "Two LLM agents (Reasoner → Coder) collaborate to solve Codeforces Python coding problems.\nYou are the **Coder**: Implement the Reasoner's plan using efficient and correct Python code.\nHandle edge cases, follow the provided strategy, and ensure clarity and correctness.\nAlways use Python.\nPlace your complete solution below the line starting with '```python```'.\n${prompt} coder: "
post_process_kwargs:
post_chat_template: " coder: "
env_path: [examples/experimental/marft/config/code_env.py:CodeEnv]
- node_id: "function_reward"
node_type: "COMPUTE"
node_role: "REWARD"
agent_group: 1
dependencies:
- "rollout_actor"
- node_id: "actor_1_old_log_prob"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
only_forward_compute: true
agent_group: 0
config: *actor1_config
dependencies:
- "function_reward"
- node_id: "actor_2_old_log_prob"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
only_forward_compute: true
agent_group: 1
config: *actor2_config
dependencies:
- "actor_1_old_log_prob"
- node_id: "reference_1_log_prob"
node_type: "MODEL_TRAIN"
node_role: "REFERENCE"
agent_group: 0
dependencies:
- "actor_2_old_log_prob"
- node_id: "reference_2_log_prob"
node_type: "MODEL_TRAIN"
node_role: "REFERENCE"
agent_group: 1
dependencies:
- "reference_1_log_prob"
- node_id: "critic_1_value"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 0
only_forward_compute: true
dependencies:
- "reference_2_log_prob"
- node_id: "critic_2_value"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 1
only_forward_compute: true
dependencies:
- "critic_1_value"
agent_options:
share_instance: 0
- node_id: "calculate_2_advantages"
node_type: "COMPUTE"
node_role: "ADVANTAGE"
agent_group: 1
dependencies:
- "critic_2_value"
- node_id: "critic_1_train"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 0
dependencies:
- "calculate_2_advantages"
- node_id: "critic_2_train"
node_type: "MODEL_TRAIN"
node_role: "CRITIC"
agent_group: 1
dependencies:
- "critic_1_train"
agent_options:
share_instance: 0
- node_id: "actor_1_train"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
agent_group: 0
config: *actor1_config
agent_options:
train_cycle: 15
dependencies:
- "critic_2_train"
- node_id: "actor_2_train"
node_type: "MODEL_TRAIN"
node_role: "ACTOR"
agent_group: 1
config: *actor2_config
agent_options:
train_cycle: 15
dependencies:
- "actor_1_train"
================================================
FILE: examples/experimental/marft/run_qwen2_5-3b_marft.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gae_marft
export MODEL_NAME=qwen3-1.7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-1.7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=128
export PPO_MINI_BATCH_SIZE_PER_NODE=64
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=10240
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.3
export ROLLOUT_TP=4
export ROLLOUT_N=1
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
export PROJECT_DIR="$(pwd)"
export DAG_WORKFLOW=$PROJECT_DIR/examples/experimental/marft/config/workflow_marft.yaml
# --- Define the Training Command and its Arguments ---
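# '$' is escaped so these expand at eval-time inside main(), after TRAIN_BATCH_SIZE,
# PPO_MINI_BATCH_SIZE, CKPT_PATH, etc. have been derived from NNODES.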
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.return_raw_chat=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=sglang
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=12288
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.agent.rewards_with_env=True
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3
actor_rollout_ref.rollout.multi_turn.use_all_traj=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
critic.optim.lr=1e-5
critic.model.use_remove_padding=True
critic.model.path=\$MODEL_PATH
critic.model.enable_gradient_checkpointing=True
critic.use_dynamic_bsz=False
critic.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
critic.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
critic.ppo_max_token_len_per_gpu=12288
critic.model.fsdp_config.param_offload=False
critic.model.fsdp_config.optimizer_offload=False
algorithm.kl_ctrl.kl_coef=0.001
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=False
dag.workflow_path=\$DAG_WORKFLOW
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/experimental/multiturn_server/run_qwen2_5-3b_grpo_multiturn_vllm.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-3b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-1.7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=256
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=1024
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
export PROJECT_DIR="$(pwd)"
export CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config
export TOOL_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml
export INTERACTION_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +%Y%m%d_%H%M%S)
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
--config-path=\$CONFIG_PATH
--config-name='gsm8k_multiturn_grpo'
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.return_raw_chat=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.mode=async
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
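    # Multi-turn tool/interaction setup (GSM8K tool agent; configs under examples/multi_turn/config)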
actor_rollout_ref.rollout.multi_turn.tool_config_path="\$TOOL_CONFIG_PATH"
actor_rollout_ref.rollout.multi_turn.interaction_config_path="\$INTERACTION_CONFIG_PATH"
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=1
actor_rollout_ref.rollout.multi_turn.max_user_turns=1
actor_rollout_ref.rollout.agent.agent_name="tool_agent"
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
if [ "$HOME" = "{your_home_path}" ] || [ -z "$HOME" ]; then echo "ERROR: Please set 'HOME' variable." >&2; exit 1; fi
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-32b-metax.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-32b
# --- Path Definitions ---
export HOME={your_home_path}  # e.g. /workspace
export TRAIN_DATA_PATH=$HOME/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/models/Qwen2.5-32B-Instruct
# Base output paths
export BASE_CKPT_PATH=$HOME/siirl_ckpts2
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=128
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.45
export ROLLOUT_TP=4
export ROLLOUT_N=4
export SAVE_FREQ=-1
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
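# To discover the right interface once MASTER_ADDR is set, one option is to ask
# the routing table which device reaches it (illustrative only):
#   ip route get "$MASTER_ADDR" | grep -oP 'dev \K\S+'
# and export the printed device name via GLOO_SOCKET_IFNAME.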
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
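# For example, TRAIN_BATCH_SIZE_PER_NODE=512 on NNODES=4 gives a global train
# batch of 2048 prompts and a PPO mini-batch of 512.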
# mx gpu env
export MACA_PATH=/opt/maca
export CUCC_PATH=${MACA_PATH}/tools/cu-bridge
export CUDA_PATH=${CUCC_PATH}
export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin
export PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}
export LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}
export PYTORCH_ENABLE_SAME_RAND_A100=1
export MACA_SMALL_PAGESIZE_ENABLE=1
export SET_DEVICE_NUMA_PREFERRED=1
# export CUDA_DEVICE_MAX_CONNECTIONS=1
export MCPYTORCH_DISABLE_PRINT=1
export MAX_JOBS=20
export VLLM_USE_V1=0
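# The MetaX stack here uses the legacy vLLM V0 engine; the other scripts in
# this directory set VLLM_USE_V1=1.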
unset PYTORCH_CUDA_ALLOC_CONF
export MCCL_ENABLE_FC=0
# export MACA_PRIORITY_QUEUE_POLICY=0xa11
# export MCCL_PCIE_BUFFER_MODE=0
# export MCCL_NET_GDR_LEVEL=SYS
# export MCCL_USE_FILE_TUNING=0
# export MCCL_ALGO=Ring
# export MCCL_DISABLE_MULTI_NODE_FABRIC=1
# export MCCL_DISABLE_OPTIC_LINK_=1
export MCCL_MAX_NCHANNELS=8
export PYTHONUNBUFFERED=1
export MCCL_IB_HCA=mlx5
export MCCL_SOCKET_IFNAME=ens1
export GLOO_SOCKET_IFNAME=ens1
export SOCKET_NIC=ens1
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.fsdp_config.fsdp_size=64
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$(($PPO_MICRO_BATCH_SIZE_PER_GPU * 4))
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
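# NOTE: the overrides above escape their variables as \$VAR so the array stores
# them literally; expansion happens only when main() runs
# `eval "${TRAINING_CMD[@]}"`.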
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
# export VLLM_USE_V1=0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
# local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
local ready_nodes=$(ray status | grep "node_" | wc -l)
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-32b-npu.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-32b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-32B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_IFNAME=enp91s0np0
export HCCL_SOCKET_IFNAME=enp91s0np0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=1024
export PPO_MINI_BATCH_SIZE_PER_NODE=128
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=4
export ROLLOUT_N=5
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.fsdp_config.fsdp_size=64
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
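# The loop above force-frees the rendezvous and GCS ports; `|| true` keeps
# `set -e` from aborting when a targeted process has already exited.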
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-72b-npu.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-72b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-72B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_IFNAME=enp91s0np0
export HCCL_SOCKET_IFNAME=enp91s0np0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=2
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=8
export ROLLOUT_N=6
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=npu_siirl_${MODEL_NAME}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.fsdp_config.fsdp_size=64
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-7b-npu-e2e_prof.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=gsm8k
export ALG=grpo
export MODEL_NAME=qwen2.5-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
export PROFILE_PATH='./profile_data'
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=16
export PPO_MINI_BATCH_SIZE_PER_NODE=16
export PPO_MICRO_BATCH_SIZE_PER_GPU=1
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=128
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=2
export ROLLOUT_N=5
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=5e-8
actor_rollout_ref.model.use_remove_padding=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.001
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
profiler.enable=True
profiler.save_path=\$PROFILE_PATH
profiler.level='level1'
profiler.ranks=[0]
profiler.profile_steps=[3]
profiler.discrete=False
profiler.with_cpu=True
profiler.with_memory=True
)
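# The profiler.* overrides above capture a single end-to-end trace: rank 0 only
# (profiler.ranks=[0]) at training step 3 (profiler.profile_steps=[3]), with
# CPU and memory activity recorded, written under $PROFILE_PATH.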
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-7b-npu-mindspeed.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-7b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=1024
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=4
export ROLLOUT_N=5
export ACTOR_REF_TP=4
export ACTOR_REF_PP=1
export ACTOR_REF_CP=1
export ACTOR_REF_SP=False
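# Megatron 3D-parallel layout for the actor/ref: the product TP*PP*CP must
# divide the world size (N_GPUS_PER_NODE * NNODES), and sequence parallelism
# (ACTOR_REF_SP) is only meaningful when ACTOR_REF_TP > 1.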
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=False
+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
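# The leading '+' is Hydra override syntax: it adds a key that is absent from
# the base config instead of overriding an existing one.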
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.fsdp_config.fsdp_size=16
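# NB: the fsdp_config.* overrides above carry over from the FSDP variant of
# this script; with actor.strategy=megatron they are presumably ignored.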
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-7b-npu.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-7b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_IFNAME=enp91s0np0
export HCCL_SOCKET_IFNAME=enp91s0np0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=1024
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=4
export ROLLOUT_N=5
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.fsdp_config.fsdp_size=16
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5-7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen2.5-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
timestamp=$(date +"%Y%m%d_%H%M%S")  # must be set before the exports below; the `local timestamp` in main() runs too late to affect them
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=False
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
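# Note: --object-store-memory and --memory take bytes, so 100000000000 is
# roughly 93 GiB per node; shrink these if your hosts have less RAM.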
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
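# Poll the GCS until it reports at least NNODES entries. `ray list nodes`
# may also count DEAD entries left over from a previous run, which is why
# every node runs `ray stop --force` before the cluster is (re)started.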
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5_vl-72b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=mm_eureka
export ALG=grpo
export MODEL_NAME=qwen2.5-vl-72b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-72B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=128
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=8
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.kl_ctrl.kl_coef=0.001
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.del_local_ckpt_after_load=False
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME:-bond0}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME:-bond0}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_IFNAME=bond0
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5_vl-7b-npu.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=geo3k
export ALG=grpo
export MODEL_NAME=qwen2.5-vl-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=2
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=2048
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.3
export ROLLOUT_TP=4
export ROLLOUT_N=5
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.image_key=images
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen2_5_vl-7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=mm_eureka
export ALG=grpo
export MODEL_NAME=qwen2.5-vl-7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen3-235b-megatron.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For config debugging (set HYDRA_FULL_ERROR=1 to get full Hydra tracebacks)
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
export RAY_DEDUP_LOGS=1
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-235b-a22b
# --- Path Definitions ---
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
export CUDA_DEVICE_MAX_CONNECTIONS=1
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=32
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=$((1024 * 2))
export MAX_RESPONSE_LENGTH=$((1024 * 8))
export MAX_MODEL_LENGTH=$((1024 * 10))
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4
export ROLLOUT_TP=16
export ROLLOUT_N=16
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=15
export MAX_CKPT_KEEP=5
export ACTOR_REF_PP=8
# export ACTOR_REF_VPP=1
export ACTOR_REF_TP=1
export ACTOR_REF_EP=8
export ACTOR_REF_CP=1
export ACTOR_REF_SP=True
export use_dynamic_bsz=False
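# Rough sanity check of the layout above (an illustration, not enforced by
# the script): each actor/ref replica spans TP*PP*CP = 1*8*1 = 8 GPUs, EP=8
# further shards the MoE experts (exact rank grouping depends on the
# Megatron build), and ROLLOUT_TP=16 stretches each vLLM engine across two
# 8-GPU nodes.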
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_zp_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${ALG}_${DATASET}_experiment
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=True
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
# actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=\$ACTOR_REF_VPP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
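# Keys prefixed with '+' are Hydra append overrides: they add fields that are
# absent from the base config instead of overriding existing ones.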
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.001
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_MODEL_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
# actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=\$ACTOR_REF_VPP
actor_rollout_ref.ref.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.ref.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.ref.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.ref.megatron.param_offload=True
actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','wandb']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=off
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
dag.enable_perf=False
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen3-235b-npu-mindspeed.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-235b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/models/Qwen3-235B-A22B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export HCCL_HOST_SOCKET_PORT_RANGE='auto'
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=64
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=2
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=1024
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.7
export ROLLOUT_TP=32
export ROLLOUT_N=5
export ACTOR_REF_TP=4
export ACTOR_REF_EP=8
export ACTOR_REF_PP=16
export ACTOR_REF_VPP=2
export ACTOR_REF_CP=1
export ACTOR_REF_SP=True
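# Back-of-the-envelope check (assuming the 16 NPUs per node configured
# below): TP*PP*CP = 4*16*1 = 64 devices per actor replica, i.e. 4 full
# nodes, while ROLLOUT_TP=32 puts each vLLM engine across 2 nodes.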
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
export RAY_DEDUP_LOGS=0
export HYDRA_FULL_ERROR=0
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=npu_${MODEL_NAME}_tp${ACTOR_REF_TP}pp${ACTOR_REF_PP}ep${ACTOR_REF_EP}_rtp${ROLLOUT_TP}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
actor_rollout_ref.actor.megatron.grad_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
+actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type='alltoall'
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_expert_capacity_factor=1.5
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permutation_async_comm=True
+actor_rollout_ref.actor.megatron.override_transformer_config.sequence_parallel=True
+actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_swiglu=True
actor_rollout_ref.ref.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.megatron.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','wandb']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen3-30b-npu-mindspeed.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-30b
export VLLM_USE_V1=1
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/models/Qwen3-30B-A3B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- GLOO Configuration ---
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export HCCL_CONNECT_TIMEOUT=7200
export HCCL_HOST_SOCKET_PORT_RANGE='auto'
export GLOO_LOG_LEVEL=INFO
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=64
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=2
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.7
export ROLLOUT_TP=8
export ROLLOUT_N=5
export ACTOR_REF_TP=4
export ACTOR_REF_EP=8
export ACTOR_REF_PP=4
export ACTOR_REF_CP=1
export ACTOR_REF_SP=True
export SAVE_FREQ=-1
export TEST_FREQ=5
export TOTAL_EPOCHS=300
export MAX_CKPT_KEEP=5
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_tp${ACTOR_REF_TP}pp${ACTOR_REF_PP}ep${ACTOR_REF_EP}_rtp${ROLLOUT_TP}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)
timestamp=$(date +"%Y%m%d_%H%M%S")  # define before use; main() previously set this too late, leaving an empty suffix
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.auto_repeat=True
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.3
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
trainer.device=npu
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
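# 100000000000 bytes is roughly 93 GiB for both the object store and the
# Ray memory cap below; lower these if the host has less free RAM.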
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
echo "Cleaning up residual distributed processes..."
pkill -f ray || true
pkill -f siirl.main_dag || true
pkill -f torchrun || true
pkill -f vllm || true
pkill -f hccl || true
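# Also free the rendezvous ports: a process still bound to the torchrun or
# Ray GCS port from a previous run would otherwise block startup.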
for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do
for pid in $(lsof -ti :$port); do
kill -9 $pid || true
done
done
sleep 3
echo "Cleanup finished."
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen3-8b-megatron.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Debugging knobs (set HYDRA_FULL_ERROR=1 for full Hydra stack traces) ---
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=$HOME/ckpts
export BASE_TENSORBOARD_PATH=$HOME/tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=128
export PPO_MINI_BATCH_SIZE_PER_NODE=16
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.45
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# ---- Key Parallelism Configuration ----
export ROLLOUT_TP=4
export ACTOR_REF_TP=4
export ACTOR_REF_PP=2
export ACTOR_REF_CP=1
export ACTOR_REF_SP=False
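# With TP=4, PP=2, CP=1 one model replica occupies 4*2*1 = 8 GPUs, exactly
# one node; every additional node contributes one data-parallel replica.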
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if the distributed backend requires it
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export timestamp=${timestamp:-$(date +%Y%m%d_%H%M%S)} # define before the exports below interpolate it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.seed=1
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.ref.megatron.param_offload=False
trainer.logger=['console']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/grpo_trainer/run_qwen3-8b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=grpo
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if the distributed backend requires it
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export timestamp=${timestamp:-$(date +%Y%m%d_%H%M%S)} # define before the exports below interpolate it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
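# 8192 covers MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH (2048 + 4096 = 6144)
# with headroom, so a full-length prompt plus response always fits the engine.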
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/gspo_trainer/run_qwen3-1.7b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For config debugging
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
export RAY_DEDUP_LOGS=1
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gspo
export MODEL_NAME=qwen3-1.7b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-1.7B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=$((1024 * 2))
export MAX_RESPONSE_LENGTH=$((1024 * 4))
export MAX_MODEL_LENGTH=$((1024 * 6))
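# MAX_MODEL_LENGTH = MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH (2k + 4k = 6k),
# so the rollout engine can always hold one full-length sample.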
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.5
export ROLLOUT_TP=1
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- GSPO Specific Parameters ---
export LOSS_MODE=gspo
export ADV_ESTIMATOR=grpo
export CLIP_RATIO_LOW=3e-4
export CLIP_RATIO_HIGH=4e-4
export CLIP_RATIO_C=10.0
export LOSS_AGG_MODE="token-mean"
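# NOTE: GSPO clips a sequence-level, length-normalized importance ratio that
# stays very close to 1, which is why these thresholds are on the order of
# 1e-4 rather than PPO's usual 0.2-style token-level clip range.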
# --- KL Configuration ---
export USE_KL_IN_REWARD=False
export KL_COEF=0.001
export USE_KL_LOSS=True
export KL_LOSS_COEF=0.01
export KL_LOSS_TYPE=low_var_kl
# --- FSDP Configuration for 1.7B ---
export FSDP_PARAM_OFFLOAD=False
export FSDP_OPTIMIZER_OFFLOAD=False
export REF_PARAM_OFFLOAD=True
# --- Sampling Parameters ---
export TEMPERATURE=1.0
export TOP_P=1.0
export TOP_K=-1
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed
# export GLOO_SOCKET_IFNAME=bond0
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export timestamp=${timestamp:-$(date +%Y%m%d_%H%M%S)} # define before the exports below interpolate it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ADV_ESTIMATOR
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.model.enable_gradient_checkpointing=True
# Actor strategy and GSPO configuration
actor_rollout_ref.actor.strategy=fsdp
actor_rollout_ref.actor.policy_loss.loss_mode=\$LOSS_MODE
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=\$CLIP_RATIO_C
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.kl_loss_type=\$KL_LOSS_TYPE
actor_rollout_ref.actor.policy_drift_coeff=0.001
# PPO configuration
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.entropy_coeff=0
# FSDP configuration for actor
actor_rollout_ref.actor.fsdp_config.param_offload=\$FSDP_PARAM_OFFLOAD
actor_rollout_ref.actor.fsdp_config.optimizer_offload=\$FSDP_OPTIMIZER_OFFLOAD
# Rollout configuration
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_MODEL_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.temperature=\$TEMPERATURE
actor_rollout_ref.rollout.top_p=\$TOP_P
actor_rollout_ref.rollout.top_k=\$TOP_K
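# top_k=-1 disables top-k filtering in vLLM; together with temperature=1.0
# and top_p=1.0 this samples from the full, unmodified distribution.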
actor_rollout_ref.rollout.calculate_log_probs=True
# Reference model configuration
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=\$REF_PARAM_OFFLOAD
# Algorithm configuration
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
# Trainer configuration
trainer.critic_warmup=0
trainer.logger='["console","tensorboard","wandb"]'
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
dag.enable_perf=False
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting GSPO training command."
echo "Command: ${TRAINING_CMD[*]}"
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi
================================================
FILE: examples/gspo_trainer/run_qwen3-235b-megatron.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For config debugging
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
export RAY_DEDUP_LOGS=1
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gspo
export MODEL_NAME=qwen3-235b-a22b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
export CUDA_DEVICE_MAX_CONNECTIONS=1
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=32 # Conservative for 235B
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=4
export MAX_PROMPT_LENGTH=$((1024 * 2))
export MAX_RESPONSE_LENGTH=$((1024 * 8))
export MAX_MODEL_LENGTH=$((1024 * 10))
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4 # Conservative for 235B
export ROLLOUT_TP=16 # High TP for 235B
export ROLLOUT_N=16
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=15
export MAX_CKPT_KEEP=5
# --- GSPO Specific Parameters ---
export LOSS_MODE=gspo
export ADV_ESTIMATOR=grpo
export CLIP_RATIO_LOW=3e-4
export CLIP_RATIO_HIGH=4e-4
export CLIP_RATIO_C=10.0
export LOSS_AGG_MODE="token-mean"
# --- KL Configuration ---
export USE_KL_IN_REWARD=False
export KL_COEF=0.001
export USE_KL_LOSS=True
export KL_LOSS_COEF=0.001
export KL_LOSS_TYPE=low_var_kl
# --- Megatron Parallelism for 235B ---
export ACTOR_REF_PP=8 # High pipeline parallel for 235B
export ACTOR_REF_TP=1 # Low tensor parallel
export ACTOR_REF_EP=8 # High expert parallel for MoE
export ACTOR_REF_CP=1 # Context parallel
export ACTOR_REF_SP=True # Sequence parallel
export use_dynamic_bsz=False
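# Sanity arithmetic: ROLLOUT_TP=16 alone implies at least two 8-GPU nodes
# for the vLLM engine, while each actor replica spans TP*PP*CP = 1*8*1 = 8
# GPUs with the MoE experts sharded 8-way (EP=8).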
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed
# export GLOO_SOCKET_IFNAME=bond0
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_zp_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export timestamp=${timestamp:-$(date +%Y%m%d_%H%M%S)} # define before the exports below interpolate it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ADV_ESTIMATOR
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=True
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.use_dynamic_bsz=\$use_dynamic_bsz
# GSPO specific loss configuration
actor_rollout_ref.actor.policy_loss.loss_mode=\$LOSS_MODE
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=\$CLIP_RATIO_C
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
# Megatron configuration for actor (235B)
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True
# PPO configuration
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.kl_loss_type=\$KL_LOSS_TYPE
# Rollout configuration (235B)
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_MODEL_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
# Reference model configuration (235B)
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.ref.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.ref.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.ref.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.ref.megatron.param_offload=True
actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
# Algorithm configuration
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
# Trainer configuration
trainer.critic_warmup=0
trainer.logger='["console","tensorboard"]'
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=off
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
dag.enable_perf=False
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting GSPO training command."
echo "Command: ${TRAINING_CMD[*]}"
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi
================================================
FILE: examples/gspo_trainer/run_qwen3-30b-gspo-megatron.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- For config debugging
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
export RAY_DEDUP_LOGS=1
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gspo
export MODEL_NAME=qwen3-30b-a3b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-30B-A3B-Base
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
export CUDA_DEVICE_MAX_CONNECTIONS=1
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=64 # Increased for 30B
export PPO_MINI_BATCH_SIZE_PER_NODE=32
export PPO_MICRO_BATCH_SIZE_PER_GPU=2
export MAX_PROMPT_LENGTH=$((1024 * 2))
export MAX_RESPONSE_LENGTH=$((1024 * 8))
export MAX_MODEL_LENGTH=$((1024 * 10))
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6 # Higher for 30B
export ROLLOUT_TP=4 # Reduced for 30B
export ROLLOUT_N=16
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=15
export MAX_CKPT_KEEP=5
# --- GSPO Specific Parameters ---
export LOSS_MODE=gspo
export ADV_ESTIMATOR=grpo
export CLIP_RATIO_LOW=3e-4
export CLIP_RATIO_HIGH=4e-4
export CLIP_RATIO_C=10.0
export LOSS_AGG_MODE="token-mean"
# --- KL Configuration ---
export USE_KL_IN_REWARD=False
export KL_COEF=0.001
export USE_KL_LOSS=True
export KL_LOSS_COEF=0.001
export KL_LOSS_TYPE=low_var_kl
# --- Megatron Parallelism for 30B ---
export ACTOR_REF_PP=2 # Reduced pipeline parallel
export ACTOR_REF_TP=4 # Tensor parallel
export ACTOR_REF_EP=1 # No expert parallel for 30B
export ACTOR_REF_CP=1 # Context parallel
export ACTOR_REF_SP=True # Sequence parallel
export use_dynamic_bsz=False
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed
# export GLOO_SOCKET_IFNAME=bond0
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export timestamp=${timestamp:-$(date +%Y%m%d_%H%M%S)} # define before the exports below interpolate it
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
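# Optional sanity check (a minimal sketch): warn early when one model
# replica (TP*PP*CP devices) does not divide the world size evenly.
WORLD_SIZE=$((N_GPUS_PER_NODE * NNODES))
REPLICA_SIZE=$((ACTOR_REF_TP * ACTOR_REF_PP * ACTOR_REF_CP))
if [ $((WORLD_SIZE % REPLICA_SIZE)) -ne 0 ]; then
  echo "WARN: TP*PP*CP=$REPLICA_SIZE does not divide world size $WORLD_SIZE" >&2
fi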
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ADV_ESTIMATOR
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=True
actor_rollout_ref.model.trust_remote_code=True
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.use_dynamic_bsz=\$use_dynamic_bsz
# GSPO specific loss configuration
actor_rollout_ref.actor.policy_loss.loss_mode=\$LOSS_MODE
actor_rollout_ref.actor.loss_agg_mode=\$LOSS_AGG_MODE
actor_rollout_ref.actor.clip_ratio_low=\$CLIP_RATIO_LOW
actor_rollout_ref.actor.clip_ratio_high=\$CLIP_RATIO_HIGH
actor_rollout_ref.actor.clip_ratio_c=\$CLIP_RATIO_C
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\$use_dynamic_bsz
# Megatron configuration for actor
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.actor.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.optimizer_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.use_mbridge=True
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
# PPO configuration
actor_rollout_ref.actor.policy_drift_coeff=0.001
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=\$USE_KL_LOSS
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=\$KL_LOSS_COEF
actor_rollout_ref.actor.kl_loss_type=\$KL_LOSS_TYPE
# Rollout configuration
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=\$MAX_MODEL_LENGTH
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
# Reference model configuration
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_PP
actor_rollout_ref.ref.megatron.expert_model_parallel_size=\$ACTOR_REF_EP
actor_rollout_ref.ref.megatron.context_parallel_size=\$ACTOR_REF_CP
actor_rollout_ref.ref.megatron.sequence_parallel=\$ACTOR_REF_SP
actor_rollout_ref.ref.megatron.param_offload=True
actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
# Algorithm configuration
algorithm.weight_factor_in_cpgd='STD_weight'
algorithm.use_kl_in_reward=\$USE_KL_IN_REWARD
algorithm.kl_ctrl.kl_coef=\$KL_COEF
# Trainer configuration
trainer.critic_warmup=0
trainer.logger='["console","tensorboard"]'
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=off
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
dag.enable_perf=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting GSPO training command."
echo "Command: ${TRAINING_CMD[*]}"
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi
================================================
FILE: examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml
================================================
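# Maps an interaction name to the dotted import path of the class that implements it.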
interaction:
- name: "gsm8k"
class_name: "siirl.execution.rollout_flow.multiturn.interactions.gsm8k_interaction.Gsm8kInteraction"
config: {}
================================================
FILE: examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml
================================================
tools:
- class_name: "siirl.execution.rollout_flow.multiturn.tools.gsm8k_tool.Gsm8kTool"
config:
type: native
tool_schema:
type: "function"
function:
name: "calc_gsm8k_reward"
description: "A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)"
parameters:
type: "object"
properties:
answer:
type: "string"
description: "The model's answer to the GSM8K math problem, must be a digits"
required: ["answer"]
================================================
FILE: examples/multi_turn/gsm8k/run_qwen2_5-3b_grpo_multiturn_sglang.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=gsm8k
export ALG=grpo
export MODEL_NAME=qwen2.5-3b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-3B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=256
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=1024
export MAX_RESPONSE_LENGTH=1024
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4
export ROLLOUT_TP=2
export ROLLOUT_N=8
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
export PROJECT_DIR="$(pwd)"
export CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config
export TOOL_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml
export INTERACTION_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export timestamp=${timestamp:-$(date +"%Y%m%d_%H%M%S")} # define before it is used in the paths below
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
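# e.g. with NNODES=4 this yields TRAIN_BATCH_SIZE=1024 and PPO_MINI_BATCH_SIZE=1024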
# --- Define the Training Command and its Arguments ---
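# Note: the '\$VAR' escapes keep these values unexpanded until 'eval' runs the command at launch time.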
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.return_raw_chat=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=sglang
actor_rollout_ref.rollout.mode=sync
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
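# Multi-turn rollout: let the model call the gsm8k tool across up to 5 assistant turns.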
actor_rollout_ref.rollout.multi_turn.enable=True
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5
actor_rollout_ref.rollout.multi_turn.tool_config_path="\$TOOL_CONFIG_PATH"
actor_rollout_ref.rollout.multi_turn.interaction_config_path="\$INTERACTION_CONFIG_PATH"
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
if [ "$HOME" = "{your_home_path}" ] || [ -z "$HOME" ]; then echo "ERROR: Please set 'HOME' variable." >&2; exit 1; fi
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/ppo_trainer/run_qwen2_5-72b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gae
export MODEL_NAME=qwen2.5-72b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen2.5-72B-Instruct
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=128
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=8
export ROLLOUT_N=1
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export timestamp=${timestamp:-$(date +"%Y%m%d_%H%M%S")} # define before it is used in the paths below
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
actor_rollout_ref.ref.fsdp_config.param_offload=True
critic.optim.lr=1e-5
critic.model.use_remove_padding=True
critic.model.path=\$MODEL_PATH
critic.model.enable_gradient_checkpointing=True
critic.use_dynamic_bsz=False
critic.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
critic.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
critic.ppo_max_token_len_per_gpu=98304
critic.model.fsdp_config.param_offload=False
critic.model.fsdp_config.optimizer_offload=False
algorithm.kl_ctrl.kl_coef=0.001
algorithm.use_kl_in_reward=False
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.del_local_ckpt_after_load=False
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
if [ "$HOME" = "{your_home_path}" ] || [ -z "$HOME" ]; then echo "ERROR: Please set 'HOME' variable." >&2; exit 1; fi
export VLLM_USE_V1=1
export GLOO_SOCKET_IFNAME=${GLOO_SOCKET_IFNAME:-bond0} # override to match your network interface
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 2
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/ppo_trainer/run_qwen3-8b-megatron.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Debugging knobs (set HYDRA_FULL_ERROR=1 for full Hydra stack traces) ---
export HYDRA_FULL_ERROR=0
export SIIRL_LOG_VERBOSITY=INFO
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gae
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export TRAIN_DATA_PATH=$HOME/data/dataset/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/dataset/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=$HOME/ckpts
export BASE_TENSORBOARD_PATH=$HOME/tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=1024
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.45
export ROLLOUT_N=1
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
export ACTOR_REF_CRITIC_TP=2
export ACTOR_REF_CRITIC_PP=2
export ACTOR_REF_CRITIC_CP=1
export ACTOR_REF_CRITIC_SP=False
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export timestamp=${timestamp:-$(date +"%Y%m%d_%H%M%S")} # define before it is used in the paths below
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.strategy=megatron
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\$ACTOR_REF_CRITIC_TP
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\$ACTOR_REF_CRITIC_PP
actor_rollout_ref.actor.megatron.context_parallel_size=\$ACTOR_REF_CRITIC_CP
actor_rollout_ref.actor.megatron.sequence_parallel=\$ACTOR_REF_CRITIC_SP
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True
actor_rollout_ref.actor.megatron.param_dtype=bfloat16
actor_rollout_ref.actor.megatron.param_offload=True
actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
actor_rollout_ref.actor.megatron.seed=1
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ACTOR_REF_CRITIC_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=True
actor_rollout_ref.rollout.free_cache_engine=True
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.strategy=megatron
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\$ACTOR_REF_CRITIC_TP
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\$ACTOR_REF_CRITIC_PP
actor_rollout_ref.ref.megatron.context_parallel_size=\$ACTOR_REF_CRITIC_CP
actor_rollout_ref.ref.megatron.sequence_parallel=\$ACTOR_REF_CRITIC_SP
actor_rollout_ref.ref.megatron.param_offload=False
critic.optim.lr=1e-5
critic.model.use_remove_padding=True
critic.model.path=\$MODEL_PATH
critic.model.enable_gradient_checkpointing=True
critic.use_dynamic_bsz=False
critic.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
critic.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
critic.ppo_max_token_len_per_gpu=98304
critic.strategy=megatron
critic.megatron.tensor_model_parallel_size=\$ACTOR_REF_CRITIC_TP
critic.megatron.pipeline_model_parallel_size=\$ACTOR_REF_CRITIC_PP
critic.megatron.context_parallel_size=\$ACTOR_REF_CRITIC_CP
critic.megatron.sequence_parallel=\$ACTOR_REF_CRITIC_SP
critic.megatron.param_offload=True
critic.megatron.optimizer_offload=True
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 5
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: examples/ppo_trainer/run_qwen3-8b.sh
================================================
#!/usr/bin/env bash
# ===================================================================================
# === USER CONFIGURATION SECTION ===
# ===================================================================================
# --- Experiment and Model Definition ---
export DATASET=deepscaler
export ALG=gae
export MODEL_NAME=qwen3-8b
# --- Path Definitions ---
export HOME={your_home_path}
export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet
export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet
export MODEL_PATH=$HOME/data/models/Qwen3-8B
# Base output paths
export BASE_CKPT_PATH=ckpts
export BASE_TENSORBOARD_PATH=tensorboard
# --- Key Training Hyperparameters ---
export TRAIN_BATCH_SIZE_PER_NODE=512
export PPO_MINI_BATCH_SIZE_PER_NODE=256
export PPO_MICRO_BATCH_SIZE_PER_GPU=8
export MAX_PROMPT_LENGTH=2048
export MAX_RESPONSE_LENGTH=4096
export ROLLOUT_GPU_MEMORY_UTILIZATION=0.6
export ROLLOUT_TP=1
export ROLLOUT_N=1
export SAVE_FREQ=30
export TEST_FREQ=10
export TOTAL_EPOCHS=30
export MAX_CKPT_KEEP=5
# --- Multi-node (Multi-machine) distributed training environments ---
# Uncomment the following line and set the correct network interface if needed for distributed backend
# export GLOO_SOCKET_IFNAME=bond0 # Modify as needed
# --- Distributed Training & Infrastructure ---
export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
export NNODES=${PET_NNODES:-1}
export NODE_RANK=${PET_NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
# --- Output Paths and Experiment Naming ---
export timestamp=${timestamp:-$(date +"%Y%m%d_%H%M%S")} # define before it is used in the paths below
export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes
export PROJECT_NAME=siirl_${DATASET}_${ALG}
export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment
export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp
export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp
# --- Calculated Global Hyperparameters ---
export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))
export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))
# --- Define the Training Command and its Arguments ---
TRAINING_CMD=(
python3 -m siirl.main_dag
algorithm.adv_estimator=\$ALG
data.train_files=\$TRAIN_DATA_PATH
data.val_files=\$TEST_DATA_PATH
data.train_batch_size=\$TRAIN_BATCH_SIZE
data.max_prompt_length=\$MAX_PROMPT_LENGTH
data.max_response_length=\$MAX_RESPONSE_LENGTH
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
actor_rollout_ref.model.path=\$MODEL_PATH
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.model.use_fused_kernels=False
actor_rollout_ref.actor.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.grad_clip=0.5
actor_rollout_ref.actor.clip_ratio=0.2
actor_rollout_ref.actor.kl_loss_coef=0.01
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.tensor_model_parallel_size=\$ROLLOUT_TP
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.gpu_memory_utilization=\$ROLLOUT_GPU_MEMORY_UTILIZATION
actor_rollout_ref.rollout.max_model_len=8192
actor_rollout_ref.rollout.enable_chunked_prefill=False
actor_rollout_ref.rollout.enforce_eager=False
actor_rollout_ref.rollout.free_cache_engine=False
actor_rollout_ref.rollout.n=\$ROLLOUT_N
actor_rollout_ref.rollout.prompt_length=\$MAX_PROMPT_LENGTH
actor_rollout_ref.rollout.response_length=\$MAX_RESPONSE_LENGTH
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.ref.fsdp_config.param_offload=True
critic.optim.lr=1e-5
critic.model.use_remove_padding=True
critic.model.path=\$MODEL_PATH
critic.model.enable_gradient_checkpointing=True
critic.use_dynamic_bsz=False
critic.ppo_micro_batch_size_per_gpu=\$PPO_MICRO_BATCH_SIZE_PER_GPU
critic.ppo_mini_batch_size=\$PPO_MINI_BATCH_SIZE
critic.ppo_max_token_len_per_gpu=98304
critic.model.fsdp_config.param_offload=False
critic.model.fsdp_config.optimizer_offload=False
algorithm.kl_ctrl.kl_coef=0.001
trainer.critic_warmup=0
trainer.logger=['console','tensorboard']
trainer.project_name=\$PROJECT_NAME
trainer.experiment_name=\$EXPERIMENT_NAME
trainer.n_gpus_per_node=\$N_GPUS_PER_NODE
trainer.nnodes=\$NNODES
trainer.save_freq=\$SAVE_FREQ
trainer.test_freq=\$TEST_FREQ
trainer.total_epochs=\$TOTAL_EPOCHS
trainer.resume_mode=auto
trainer.max_actor_ckpt_to_keep=\$MAX_CKPT_KEEP
trainer.default_local_dir=\$CKPT_PATH
trainer.val_before_train=True
)
# ===================================================================================
# === MAIN EXECUTION LOGIC & INFRASTRUCTURE ===
# ===================================================================================
# --- Boilerplate Setup ---
set -e
set -o pipefail
set -x
# --- Infrastructure & Boilerplate Functions ---
start_ray_cluster() {
local RAY_HEAD_WAIT_TIMEOUT=600
export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}
export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200
export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120
local ray_start_common_opts=(
--num-gpus "$N_GPUS_PER_NODE"
--object-store-memory 100000000000
--memory 100000000000
)
if [ "$NNODES" -gt 1 ]; then
if [ "$NODE_RANK" = "0" ]; then
echo "INFO: Starting Ray head node on $(hostname)..."
export RAY_ADDRESS="$RAY_MASTER_ADDR:$RAY_MASTER_PORT"
ray start --head --port="$RAY_MASTER_PORT" --dashboard-port="$RAY_DASHBOARD_PORT" "${ray_start_common_opts[@]}" --system-config='{"gcs_server_request_timeout_seconds": 60, "gcs_rpc_server_reconnect_timeout_s": 60}'
local start_time=$(date +%s)
while ! ray health-check --address "$RAY_ADDRESS" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head node. Exiting." >&2; ray stop --force; exit 1; fi
echo "Head node not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head node is healthy."
else
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO: Worker node $(hostname) waiting for head at $head_node_address..."
local start_time=$(date +%s)
while ! ray health-check --address "$head_node_address" &>/dev/null; do
if [ "$(( $(date +%s) - start_time ))" -ge "$RAY_HEAD_WAIT_TIMEOUT" ]; then echo "ERROR: Timed out waiting for head. Exiting." >&2; exit 1; fi
echo "Head not healthy yet. Retrying in 5s..."
sleep 5
done
echo "INFO: Head is healthy. Worker starting..."
ray start --address="$head_node_address" "${ray_start_common_opts[@]}"
fi
else
echo "INFO: Starting Ray in single-node mode..."
ray start --head "${ray_start_common_opts[@]}"
fi
}
# --- Main Execution Function ---
main() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
ray stop --force
if [ "$HOME" = "{your_home_path}" ] || [ -z "$HOME" ]; then echo "ERROR: Please set 'HOME' variable." >&2; exit 1; fi
export VLLM_USE_V1=1
export GLOO_SOCKET_TIMEOUT=600
export GLOO_TCP_TIMEOUT=600
export GLOO_LOG_LEVEL=DEBUG
export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}
export RAY_MASTER_ADDR=$MASTER_ADDR
start_ray_cluster
if [ "$NNODES" -gt 1 ] && [ "$NODE_RANK" = "0" ]; then
echo "Waiting for all $NNODES nodes to join..."
local TIMEOUT=600; local start_time=$(date +%s)
while true; do
if [ "$(( $(date +%s) - start_time ))" -ge "$TIMEOUT" ]; then echo "Error: Timeout waiting for nodes." >&2; exit 1; fi
local ready_nodes=$(ray list nodes --format=json | python3 -c "import sys, json; print(len(json.load(sys.stdin)))")
if [ "$ready_nodes" -ge "$NNODES" ]; then break; fi
echo "Waiting... ($ready_nodes / $NNODES nodes ready)"
sleep 2
done
echo "All $NNODES nodes have joined."
fi
if [ "$NODE_RANK" = "0" ]; then
echo "INFO [RANK 0]: Starting main training command."
eval "${TRAINING_CMD[@]}" "$@"
echo "INFO [RANK 0]: Training finished."
sleep 30; ray stop --force >/dev/null 2>&1
elif [ "$NNODES" -gt 1 ]; then
local head_node_address="$MASTER_ADDR:$RAY_MASTER_PORT"
echo "INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address."
while ray health-check --address "$head_node_address" &>/dev/null; do sleep 15; done
echo "INFO [RANK $NODE_RANK]: Head node down. Exiting."
fi
echo "INFO: Script finished on rank $NODE_RANK."
}
# --- Script Entrypoint ---
main "$@"
================================================
FILE: pyproject.toml
================================================
# ===================================================================
# pyproject.toml for siirl
#
# PEP 621-compliant configuration file for project metadata,
# build system, and tool configurations. This file works in
# conjunction with a minimal setup.py shim.
# ===================================================================
# -------------------------------
# Build System
# -------------------------------
[build-system]
requires = [
"setuptools>=61.0",
"setuptools_scm[toml]>=6.2",
"wheel"
]
build-backend = "setuptools.build_meta"
# -------------------------------
# Project Metadata (PEP 621)
# -------------------------------
[project]
name = "siirl"
# Version is derived from git metadata via setuptools_scm. See [tool.setuptools_scm].
dynamic = ["version"]
description = "siirl: A Decentralized Multi-Agent Reinforcement Learning Framework"
license = {file = "LICENSE"}
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.10"
# --- Author & URL Information ---
authors = [
{ name = "Shanghai Innovation Institute - AI Infra Team", email = "llm19900326@gmail.com" },
]
# --- Project Discovery ---
keywords = ["reinforcement learning", "multi-agent", "decentralized", "rl", "ai"]
# Standardized classifiers from https://pypi.org/classifiers/
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
# --- Dependencies ---
# Runtime dependencies required by the project.
dependencies = [
"accelerate",
"codetiming",
"datasets>=4.0.0",
"dill",
"hydra-core",
"numpy",
"pandas",
"peft",
"pyarrow>=19.0.0",
"pybind11",
"pylatexenc",
"ray[default]>=2.47.1",
"torchdata",
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"wandb",
"tensorboard",
"mathruler",
"math_verify",
"timm",
"imageio",
"loguru",
"packaging>=20.0",
"dacite",
"qwen_vl_utils",
"scipy",
"fastapi",
"transformers",
"math-verify",
"vllm>=0.8.5.post1",
]
# --- Optional Dependencies ---
# Corresponds to 'extras_require' in setup.py.
# Install with: pip install "siirl[gpu]"
[project.optional-dependencies]
# For core development and releasing
dev = [
"ruff",
"pytest",
"build",
"twine",
"pre-commit",
"py-spy",
]
test = [
"pytest",
"pre-commit",
"py-spy",
"pyext",
]
geo = ["mathruler"]
gpu = ["liger-kernel", "flash-attn"]
sglang = [
"tensordict>=0.8.0,<=0.9.1,!=0.9.0",
"sglang[all]>=0.4.6.post5",
"torch-memory-saver>=0.0.5",
"torch>=2.6.0",
]
# --- Project URLs ---
# This table should only contain string key-value pairs for URLs.
[project.urls]
"Homepage" = "https://github.com/sii-research/siiRL"
"Bug Tracker" = "https://github.com/sii-research/siiRL/issues"
"Repository" = "https://github.com/sii-research/siiRL"
# -------------------------------
# Tool: Ruff (Linting)
# -------------------------------
[tool.ruff]
line-length = 120 # TODO: Reduce this to a more reasonable value
[tool.ruff.lint]
isort = {known-first-party = ["siirl"]}
select = [ "E", "F", "UP", "B", "I", "G" ]
ignore = [ "F405", "F403", "E731", "B007", "UP032", "UP007", "G004" ]
# -------------------------------
# Tool: Setuptools
# -------------------------------
[tool.setuptools]
include-package-data = true
# Modern equivalent of find_packages()
packages = { find = {} }
[tool.setuptools_scm]
write_to = "siirl/_version.py"
[tool.setuptools.package-dir]
"" = "."
[tool.setuptools.package-data]
siirl = [
"client/config/*.yaml"
]
================================================
FILE: requirements-npu.txt
================================================
accelerate
codetiming
datasets>=4.0.0
dill
hydra-core
numpy
pandas
peft
pyarrow>=19.0.0
pybind11
pylatexenc
ray[default]>=2.47.1
torchdata
tensordict>=0.8.0,<=0.9.1,!=0.9.0
transformers
wandb
tensorboard
mathruler
math_verify
timm
imageio
loguru
packaging>=20.0
dacite
qwen_vl_utils
scipy
fastapi
torch_npu==2.5.1
vllm>=0.9.1
vllm_ascend>=0.9.1
mbridge==0.13.0
================================================
FILE: requirements.txt
================================================
accelerate
codetiming
datasets>=4.0.0
dill
hydra-core
numpy
pandas
peft
pyarrow>=19.0.0
pybind11
pylatexenc
ray[default]>=2.47.1
torchdata
tensordict>=0.8.0,<=0.9.1,!=0.9.0
transformers
wandb
tensorboard
mathruler
math_verify
timm
imageio
loguru
packaging>=20.0
dacite
qwen_vl_utils
scipy
fastapi
vllm>=0.8.5.post1
================================================
FILE: setup.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
# This is a "shim" setup.py file that delegates all configuration
# to the pyproject.toml file. This is the recommended approach for
# projects that need to maintain a setup.py for compatibility while
# adopting modern packaging standards.
#
# All metadata, dependencies, and package data are defined in pyproject.toml.
# This setup() call is intentionally left empty.
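# With this shim in place, the standard frontends keep working, e.g.
# `pip install -e .` for editable installs or `python -m build` for wheels.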
setup()
================================================
FILE: siirl/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from importlib.metadata import version
from packaging.version import parse as parse_version
from importlib.metadata import PackageNotFoundError
from siirl.utils.extras.device import is_npu_available
from siirl.utils.logger.logging_utils import set_basic_config
set_basic_config()
__all__ = []
if os.getenv("SIIRL_USE_MODELSCOPE", "False").lower() == "true":
import importlib.util
if importlib.util.find_spec("modelscope") is None:
raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
# Patch hub to download models from modelscope to speed up.
from modelscope.utils.hf_util import patch_hub
patch_hub()
if is_npu_available:
from .models.transformers import npu_patch as npu_patch
package_name = "transformers"
required_version_spec = "4.52.4"
try:
installed_version = version(package_name)
installed = parse_version(installed_version)
required = parse_version(required_version_spec)
if not installed >= required:
raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.")
except PackageNotFoundError:
raise ImportError(f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}")
================================================
FILE: siirl/dag_worker/__init__.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: siirl/dag_worker/checkpoint_manager.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint save/load operations for distributed training."""
import os
import torch
import torch.distributed as dist
from typing import Dict, Optional, Any
from loguru import logger
from siirl.execution.dag.node import NodeRole, NodeType
from siirl.params import SiiRLArguments
from siirl.dag_worker.constants import DAGConstants
from siirl.dag_worker.dag_utils import generate_node_worker_key
from siirl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
class CheckpointManager:
"""Manages distributed checkpoint save/load with atomic commits."""
def __init__(
self,
config: SiiRLArguments,
rank: int,
gather_group: dist.ProcessGroup,
workers: Dict[str, Any],
taskgraph: Any,
dataloader: Any,
first_rollout_node: Any,
get_node_dp_info_fn: callable
):
self.config = config
self.rank = rank
self.gather_group = gather_group
self.workers = workers
self.taskgraph = taskgraph
self.dataloader = dataloader
self.first_rollout_node = first_rollout_node
self._get_node_dp_info = get_node_dp_info_fn
def save_checkpoint(self, global_steps: int) -> None:
"""Save checkpoint atomically across all ranks."""
step_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{global_steps}")
os.makedirs(step_dir, exist_ok=True)
dist.barrier(self.gather_group)
logger.info(f"Rank {self.rank}: Saving checkpoint for global_step {global_steps} to {step_dir}")
self._save_model_states(global_steps, step_dir)
self._save_dataloader_state(step_dir)
logger.debug(f"Rank {self.rank}: All data saved. Waiting at barrier before committing checkpoint.")
dist.barrier(self.gather_group)
if self.rank == 0:
self._commit_checkpoint(global_steps)
dist.barrier(self.gather_group)
logger.info(f"Rank {self.rank}: Finished saving and committing checkpoint for step {global_steps}.")
def _save_model_states(self, global_steps: int, step_dir: str) -> None:
"""Save model states for all trainable nodes."""
saved_worker_keys = set()
for node in self.taskgraph.nodes.values():
if node.node_type != NodeType.MODEL_TRAIN:
continue
if node.node_role not in [NodeRole.ACTOR, NodeRole.CRITIC]:
continue
node_worker_key = generate_node_worker_key(node)
if node_worker_key in saved_worker_keys:
continue
worker = self.workers[node_worker_key]
sub_dir_name = f"{node.node_role.name.lower()}_agent_{node.agent_group}"
checkpoint_path = os.path.join(step_dir, sub_dir_name)
role_name_for_config = node.node_role.name.lower()
max_ckpt_keep = getattr(
self.config.trainer,
f"max_{role_name_for_config}_ckpt_to_keep",
10
)
worker.save_checkpoint(
local_path=checkpoint_path,
global_step=global_steps,
max_ckpt_to_keep=max_ckpt_keep
)
saved_worker_keys.add(node_worker_key)
logger.debug(
f"Rank {self.rank}: Saved {node.node_role.name} checkpoint for agent {node.agent_group} "
f"to {checkpoint_path}"
)
def _save_dataloader_state(self, step_dir: str) -> None:
"""Save dataloader state (only TP rank 0 and PP rank 0 per DP group)."""
_, dp_rank, tp_rank, _, pp_rank, _ = self._get_node_dp_info(self.first_rollout_node)
if tp_rank == 0 and pp_rank == 0:
dataloader_path = os.path.join(step_dir, f"data_dp_rank_{dp_rank}.pt")
dataloader_state = self.dataloader.state_dict()
torch.save(dataloader_state, dataloader_path)
logger.debug(
f"Rank {self.rank} (DP_Rank {dp_rank}, TP_Rank {tp_rank}, PP_Rank {pp_rank}): "
f"Saved dataloader state to {dataloader_path}"
)
def _commit_checkpoint(self, global_steps: int) -> None:
"""Atomically commit checkpoint by writing tracker file (rank 0 only)."""
tracker_file = os.path.join(
self.config.trainer.default_local_dir,
"latest_checkpointed_iteration.txt"
)
with open(tracker_file, "w") as f:
f.write(str(global_steps))
logger.info(f"Rank 0: Checkpoint for step {global_steps} successfully committed.")
def load_checkpoint(self) -> int:
"""Load checkpoint and return global step to resume from."""
if self.config.trainer.resume_mode == "disable":
if self.rank == 0:
logger.info("Checkpoint loading is disabled. Starting from scratch.")
return 0
checkpoint_path = self._determine_checkpoint_path()
checkpoint_path_container = [checkpoint_path]
dist.broadcast_object_list(checkpoint_path_container, src=0)
global_step_folder = checkpoint_path_container[0]
if global_step_folder is None:
if self.rank == 0:
logger.info("No valid checkpoint to load. Training will start from step 0.")
dist.barrier(self.gather_group)
return 0
try:
global_steps = int(os.path.basename(global_step_folder).split("global_step_")[-1])
logger.info(
f"Rank {self.rank}: Resuming from checkpoint. "
f"Setting global_steps to {global_steps}."
)
except (ValueError, IndexError) as e:
raise ValueError(
f"Could not parse global step from checkpoint path: {global_step_folder}"
) from e
self._load_model_states(global_step_folder)
self._load_dataloader_state(global_step_folder)
dist.barrier(self.gather_group)
logger.info(f"Rank {self.rank}: Finished loading all checkpoint components.")
return global_steps
def _determine_checkpoint_path(self) -> Optional[str]:
"""Determine checkpoint path (rank 0 only)."""
if self.rank != 0:
return None
checkpoint_dir = self.config.trainer.default_local_dir
resume_from_path = self.config.trainer.resume_from_path
path_to_load = None
if self.config.trainer.resume_mode == "auto":
latest_path = find_latest_ckpt_path(checkpoint_dir)
if latest_path:
logger.info(f"Rank 0: Auto-found latest checkpoint at {latest_path}")
path_to_load = latest_path
elif self.config.trainer.resume_mode == "resume_path" and resume_from_path:
logger.info(f"Rank 0: Attempting to load from specified path: {resume_from_path}")
path_to_load = resume_from_path
if path_to_load and os.path.exists(path_to_load):
return path_to_load
else:
logger.warning(
f"Rank 0: Checkpoint path not found or invalid: '{path_to_load}'. "
f"Starting from scratch."
)
return None
def _load_model_states(self, global_step_folder: str) -> None:
"""Load model states for all trainable nodes."""
loaded_worker_keys = set()
for node in self.taskgraph.nodes.values():
if node.node_type != NodeType.MODEL_TRAIN:
continue
if node.node_role not in [NodeRole.ACTOR, NodeRole.CRITIC]:
continue
node_worker_key = generate_node_worker_key(node)
if node_worker_key in loaded_worker_keys:
continue
worker = self.workers[node_worker_key]
sub_dir_name = f"{node.node_role.name.lower()}_agent_{node.agent_group}"
checkpoint_path = os.path.join(global_step_folder, sub_dir_name)
if os.path.exists(checkpoint_path):
worker.load_checkpoint(
local_path=checkpoint_path,
del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
loaded_worker_keys.add(node_worker_key)
logger.debug(
f"Rank {self.rank}: Loaded {node.node_role.name} checkpoint for agent "
f"{node.agent_group} from {checkpoint_path}"
)
else:
logger.warning(
f"Rank {self.rank}: Checkpoint for agent {node.agent_group}'s "
f"{node.node_role.name} not found at {checkpoint_path}. "
f"Weights will be from initialization. "
f"If has multi-agent, will share the same checkpoint in agents"
)
def _load_dataloader_state(self, global_step_folder: str) -> None:
"""Load dataloader state for current DP group."""
_, dp_rank, _, _, _, _ = self._get_node_dp_info(self.first_rollout_node)
dataloader_path = os.path.join(global_step_folder, f"data_dp_rank_{dp_rank}.pt")
if os.path.exists(dataloader_path):
dataloader_state = torch.load(dataloader_path, map_location="cpu")
self.dataloader.load_state_dict(dataloader_state)
logger.debug(
f"Rank {self.rank} (DP_Rank {dp_rank}): Loaded dataloader state from "
f"{dataloader_path}"
)
else:
logger.warning(
f"Rank {self.rank} (DP_Rank {dp_rank}): Dataloader checkpoint not found at "
f"{dataloader_path}. Sampler state will not be restored, which may lead to "
f"data inconsistency."
)
================================================
FILE: siirl/dag_worker/constants.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from siirl.execution.dag.node import NodeRole
class DAGInitializationError(Exception):
"""Custom exception for failures during DAGWorker initialization."""
pass
class DAGConstants:
"""Centralized constants to improve maintainability and avoid magic strings."""
# Worker role mapping
WORKER_ROLE_MAPPING: Dict[NodeRole, str] = {
NodeRole.ACTOR: "actor",
NodeRole.ROLLOUT: "rollout",
NodeRole.REFERENCE: "ref",
}
# Configuration keys
INTERN_CONFIG: str = "intern_config"
# Framework strategy names
FSDP_STRATEGIES: List[str] = ["fsdp", "fsdp2"]
MEGATRON_STRATEGYS: List[str] = ["megatron"]
# keep this for backward compatibility
MEGATRON_STRATEGY: str = "megatron"
# Metric group order
METRIC_GROUP_ORDER = ["step", "training", "actor", "critic", "perf", "response_length", "response", "prompt_length", "prompt", "dapo_sampling", "global_seqlen", "timing_s", "timing_per_token_ms", "perf/total_num_tokens", "perf/time_per_step", "perf/throughput"]
================================================
FILE: siirl/dag_worker/core_algos.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO-like algorithms.
"""
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
import math
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Optional
import numpy as np
import torch
from omegaconf import DictConfig
from loguru import logger
import siirl.utils.model_utils.torch_functional as siirl_F
from siirl.params.model_args import AlgorithmArguments, ActorArguments
from siirl.execution.scheduler.enums import AdvantageEstimator
from tensordict import TensorDict
PolicyLossFn = Callable[
[
torch.Tensor, # old_log_prob
torch.Tensor, # log_prob
torch.Tensor, # advantages
torch.Tensor, # response_mask
str, # loss_agg_mode
Optional[DictConfig | AlgorithmArguments], # config
torch.Tensor | None, # rollout_log_probs
],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
"""Register a policy loss function with the given name.
Args:
name (str): The name to register the policy loss function under.
Returns:
function: Decorator function that registers the policy loss function.
"""
def decorator(func: PolicyLossFn) -> PolicyLossFn:
POLICY_LOSS_REGISTRY[name] = func
return func
return decorator
def get_policy_loss_fn(name):
"""Get the policy loss with a given name.
Args:
name: `(str)`
The name of the policy loss.
Returns:
`(callable)`: The policy loss function.
"""
loss_name = name
if loss_name not in POLICY_LOSS_REGISTRY:
raise ValueError(
f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}"
)
return POLICY_LOSS_REGISTRY[loss_name]
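# Illustrative sketch (editor's addition, not part of the original registry): how a
# custom policy loss could be registered and retrieved. The name "toy" and the helper
# function below are hypothetical; register_policy_loss / get_policy_loss_fn are the real API.
def _example_policy_loss_registry():
    @register_policy_loss("toy")
    def toy_loss(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, config, rollout_log_probs):
        # A trivial REINFORCE-style objective; agg_loss is defined later in this module
        # and is resolved at call time.
        pg_loss = agg_loss(-log_prob * advantages, response_mask, loss_agg_mode)
        zero = torch.tensor(0.0)
        return pg_loss, zero, zero, zero

    assert get_policy_loss_fn("toy") is toy_loss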
def compute_response_mask(data: TensorDict):
"""Compute the attention mask for the response part of the sequence.
Handles both 2D responses (NLP) and 3D responses (Embodied AI).
Returns:
torch.Tensor: The attention mask for the response tokens (always 2D).
"""
responses = data["responses"]
attention_mask = data["attention_mask"]
batch_size = responses.size(0)
# Handle 3D responses (Embodied AI): (batch_size, traj_len, action_token_len)
if responses.ndim == 3:
traj_len = responses.size(1)
action_token_len = responses.size(2)
# Check if attention_mask is also 3D
if attention_mask.ndim == 3:
# attention_mask: (batch_size, traj_len, tot_pad_len)
# Extract response part from last dimension: (batch_size, traj_len, action_token_len)
response_mask = attention_mask[:, :, -action_token_len:]
# Flatten to 2D: (batch_size, traj_len * action_token_len)
response_mask = response_mask.reshape(batch_size, -1)
else:
# attention_mask is 2D: (batch_size, total_length)
# Calculate flattened response_length and slice
response_length = traj_len * action_token_len
response_mask = attention_mask[:, -response_length:]
# Handle 2D responses (NLP): (batch_size, response_length)
elif responses.ndim == 2:
response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]
else:
raise ValueError(f"Unexpected responses shape: {responses.shape}, ndim={responses.ndim}")
return response_mask
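# Editor's sketch with toy shapes: in the 2D NLP case the response mask is simply the
# tail of the attention mask aligned with the response tokens.
def _example_compute_response_mask():
    data = TensorDict(
        {
            "responses": torch.zeros(2, 3, dtype=torch.long),  # (bs=2, response_length=3)
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 0, 0]]),  # (bs=2, total_length=6)
        },
        batch_size=[2],
    )
    mask = compute_response_mask(data)
    # mask equals attention_mask[:, -3:] -> [[1, 1, 0], [1, 0, 0]]
    assert mask.shape == (2, 3)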
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
"""Decorator to register a advantage estimator function with a given name.
Args:
name_or_enum: `(str)` or `(AdvantageEstimator)`
The name or enum of the advantage estimator.
"""
def decorator(fn):
name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:
raise ValueError(
f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}"
)
ADV_ESTIMATOR_REGISTRY[name] = fn
return fn
return decorator
def get_adv_estimator_fn(name_or_enum):
"""Get the advantage estimator function with a given name.
Args:
name_or_enum: `(str)` or `(AdvantageEstimator)`
The name or enum of the advantage estimator.
Returns:
`(callable)`: The advantage estimator function.
"""
name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
if name not in ADV_ESTIMATOR_REGISTRY:
raise ValueError(f"Unknown advantage estimator simply: {name}")
return ADV_ESTIMATOR_REGISTRY[name]
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef, target_kl, horizon):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
"""Update the KL coefficient based on current KL divergence.
Args:
current_kl (float): Current KL divergence value.
n_steps (int): Number of steps taken.
"""
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current_kl, n_steps):
"""Update method for fixed KL controller (no-op).
Args:
current_kl (float): Current KL divergence value (unused).
n_steps (int): Number of steps taken (unused).
"""
pass
def get_kl_controller(kl_ctrl):
"""Factory function to create appropriate KL controller based on configuration.
Args:
kl_ctrl: Configuration object containing KL controller settings.
Returns:
KL controller instance (FixedKLController or AdaptiveKLController).
Raises:
NotImplementedError: If controller type is not supported.
AssertionError: If adaptive controller horizon is not positive.
"""
if kl_ctrl.type == "fixed":
return FixedKLController(kl_coef=kl_ctrl.kl_coef)
elif kl_ctrl.type == "adaptive":
assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
else:
raise NotImplementedError
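# Editor's usage sketch: the SimpleNamespace below is a stand-in for the real kl_ctrl
# config (any object exposing .type, .kl_coef, .target_kl and .horizon works here).
def _example_kl_controller():
    from types import SimpleNamespace

    cfg = SimpleNamespace(type="adaptive", kl_coef=0.2, target_kl=0.1, horizon=10000)
    ctrl = get_kl_controller(cfg)
    ctrl.update(current_kl=0.2, n_steps=256)  # observed KL above target
    assert ctrl.value > 0.2  # coefficient is nudged upward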
@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
values: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
gamma: `(float)`
discount factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
nextvalues = 0
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam_ = delta + gamma * lam * lastgaelam
# skip values and TD-error on observation tokens
nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = siirl_F.masked_whiten(advantages, response_mask)
return advantages, returns
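# Editor's toy example: with gamma = lam = 1.0 and a zero value baseline, the returns
# reduce to reward-to-go, so an outcome reward of 1.0 on the final token propagates to
# every position. Advantages are additionally whitened, so only shapes are checked here.
def _example_gae():
    bs, resp_len = 2, 4
    rewards = torch.zeros(bs, resp_len)
    rewards[:, -1] = 1.0  # outcome reward on the final token
    values = torch.zeros(bs, resp_len)
    mask = torch.ones(bs, resp_len)
    adv, ret = compute_gae_advantage_return(rewards, values, mask, gamma=1.0, lam=1.0)
    assert adv.shape == (bs, resp_len)
    assert torch.allclose(ret, torch.ones(bs, resp_len))  # reward-to-go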
@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo")
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgorithmArguments] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length)
index: `(np.ndarray)`
index array for grouping
epsilon: `(float)`
small value to avoid division by zero
norm_adv_by_std_in_grpo: `(bool)`
whether to scale the GRPO advantage
config: `(Optional[AlgorithmArguments])`
algorithm configuration object
Note:
If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
Returns:
advantages: `(torch.Tensor)`
shape is (bs, response_length)
returns: `(torch.Tensor)`
shape is (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
if isinstance(index[i], torch.Tensor):
idx_key = index[i].item()
else:
idx_key = index[i]
id2score[idx_key].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if isinstance(index[i], torch.Tensor):
idx_key = index[i].item()
else:
idx_key = index[i]
if norm_adv_by_std_in_grpo:
scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)
else:
scores[i] = scores[i] - id2mean[idx_key]
scores = scores.unsqueeze(-1) * response_mask
return scores, scores
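# Editor's toy example: two prompts ("a", "b") with two responses each. The advantage
# is the group-normalized outcome score, broadcast over the response tokens, so the two
# responses of group "a" receive symmetric advantages of opposite sign.
def _example_grpo_advantage():
    rewards = torch.zeros(4, 3)
    rewards[:, -1] = torch.tensor([1.0, 0.0, 1.0, 1.0])  # one outcome score per response
    mask = torch.ones(4, 3)
    index = np.array(["a", "a", "b", "b"])  # grouping by prompt
    adv, _ = compute_grpo_outcome_advantage(rewards, mask, index)
    assert torch.isclose(adv[0, 0], -adv[1, 0])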
def compute_marft_gae_advantage_return(
data: TensorDict,
pre_agent_group_ids,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""
Args:
data: `TensorDict`
pre_agent_group_ids: `List`
IDs of the preceding agent groups
gamma: `(float)`
discount factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
key_prefix = "agent_group_"
async def compute_traj_adv():
pass
token_level_rewards = []
values = []
response_mask = []
advantages = []
returns = []
for agent_group in pre_agent_group_ids:
key = key_prefix + str(agent_group)
token_level_rewards.append(data.batch[key + "_token_level_rewards"])
values.append(data.batch[key + "_values"])
response_mask.append(data.batch[key + "_response_mask"])
advantages.append(torch.zeros_like(response_mask[-1]))
returns.append(torch.zeros_like(advantages[-1]))
token_level_rewards.append(data.batch["token_level_rewards"])
values.append(data.batch["values"])
response_mask.append(data.batch["response_mask"])
advantages.append(torch.zeros_like(response_mask[-1]))
returns.append(torch.zeros_like(advantages[-1]))
pre_agent_group_ids.append(pre_agent_group_ids[-1] + 1)
with torch.no_grad():
seen = set()
dp_start_bs = [
i for i, s in enumerate(data.non_tensor_batch["request_id"]) if s not in seen and not seen.add(s)
]
# loop over all batch entries
for bs_id in dp_start_bs:
# GAE accumulator, carried from the last token of the last agent in the last trajectory
gae = 0
traj_len = data.non_tensor_batch["traj_len"][bs_id]
# loop over each trajectory; trajectories are stored in reverse order
for traj_idx in range(traj_len):
# loop over each agent of the trajectory
traj_bs_id = bs_id + traj_idx
# assert traj_idx == traj_len - data.non_tensor_batch["traj_step"][traj_bs_id] - 1,
# f'traj_idx {traj_idx}, traj_bs_id: {traj_bs_id}, traj_step
# {data.non_tensor_batch["traj_step"][traj_bs_id]}, traj_len {traj_len},
# request {data.non_tensor_batch["request_id"][traj_bs_id]},request_data {data} '
for agent_idx in reversed(pre_agent_group_ids):
gen_len = response_mask[agent_idx][traj_bs_id].sum()
# loop each token of agent
for t in reversed(range(gen_len)):
rew = token_level_rewards[agent_idx][traj_bs_id, t]
v = values[agent_idx][traj_bs_id, t]
if agent_idx == pre_agent_group_ids[-1]:
# last_agent
if t == gen_len - 1:
# last_token
if traj_idx == 0:
v_next = 0
else:
v_next = values[0][traj_bs_id - 1, 0]
delta = rew + gamma * v_next - v
gae = delta + gamma * lam * gae
else:
v_next = values[agent_idx][traj_bs_id, t + 1]
delta = gamma * v_next - v
gae = delta + gamma * lam * gae
else:
# not last agent
if t == gen_len - 1:
# last_token
v_next = values[agent_idx + 1][traj_bs_id, 0]
delta = rew + gamma * v_next - v
gae = delta + gamma * lam * gae
else:
v_next = values[agent_idx][traj_bs_id, t + 1]
delta = gamma * v_next - v
gae = delta + gamma * lam * gae
advantages[agent_idx][traj_bs_id, t] = gae
returns[agent_idx][traj_bs_id, t] = gae + v
for agent_idx in pre_agent_group_ids:
advantages[agent_idx] = siirl_F.masked_whiten(advantages[agent_idx], response_mask[agent_idx])
if agent_idx != pre_agent_group_ids[-1]:
data.batch[key_prefix + str(agent_idx) + "_advantages"] = advantages[agent_idx]
data.batch[key_prefix + str(agent_idx) + "_returns"] = returns[agent_idx]
else:
data.batch["advantages"] = advantages[agent_idx]
data.batch["returns"] = returns[agent_idx]
return
def compute_cpgd_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
weight_factor_in_cpgd: str = "STD_weight",
):
"""
Compute advantage for CPGD, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
weight_factor_in_cpgd: `(str)`
which weighting factor to use: the STD weight as in GRPO, or the clip_filter_like_weight.
choices: {STD_weight, clip_filter_like_weight, naive}
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if weight_factor_in_cpgd == "STD_weight":
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
elif weight_factor_in_cpgd == "clip_filter_like_weight":
count_no_0_adv = sum(v != 0 for v in id2std.values())
scores[i] = (scores[i] - id2mean[index[i]]) * (bsz / count_no_0_adv).clamp(max=3.0)
elif weight_factor_in_cpgd == "naive":
scores[i] = scores[i] - id2mean[index[i]]
else:
raise NotImplementedError
scores = scores.unsqueeze(-1) * response_mask
return scores, scores
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
"""Compute token-level rewards with KL penalty.
Args:
token_level_scores (torch.Tensor): Token-level reward scores.
old_log_prob (torch.Tensor): Log probabilities from current policy.
ref_log_prob (torch.Tensor): Log probabilities from reference policy.
kl_ratio (float): KL penalty coefficient.
Returns:
torch.Tensor: Token-level rewards with KL penalty applied.
"""
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
"""
Aggregate the loss matrix into a scalar.
Args:
loss_mat: `(torch.Tensor)`:
shape: (bs, response_length)
loss_mask: `(torch.Tensor)`:
shape: (bs, response_length)
loss_agg_mode: `(str)`
method to aggregate the loss matrix into a scalar. choices:
"token-mean", "seq-mean-token-sum", "seq-mean-token-mean", "seq-mean-token-sum-norm"
Returns:
loss: `a scalar torch.Tensor`
aggregated loss
"""
if loss_agg_mode == "token-mean":
loss = siirl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
loss = torch.mean(seq_losses) # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean
loss = torch.mean(seq_losses) # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor
# (loss_mask.shape[-1]) should ideally be constant
# throughout training to well-replicate the DrGRPO paper.
# TODO: Perhaps add user-defined normalizer argument to
# agg_loss to ensure divisor stays constant throughout.
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
return loss
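# Editor's toy example contrasting two aggregation modes: with unequal sequence lengths,
# "token-mean" weights long sequences more heavily, while "seq-mean-token-mean" gives
# every sequence equal weight.
def _example_agg_loss():
    loss_mat = torch.tensor([[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]])
    loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]])
    token_mean = agg_loss(loss_mat, loss_mask, "token-mean")  # (4 * 1 + 3) / 5 = 1.4
    seq_mean = agg_loss(loss_mat, loss_mask, "seq-mean-token-mean")  # (1 + 3) / 2 = 2.0
    assert torch.isclose(token_mean, torch.tensor(1.4))
    assert torch.isclose(seq_mean, torch.tensor(2.0))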
@register_policy_loss("cpgd")
def compute_policy_loss_cpgd(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None, # Use your config class
rollout_is_weights: torch.Tensor | None = None, # Keep signature consistent, but unused in this function
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the CPGD policy objective by directly clipping log_prob.
This function replicates the logic from the original siirl 'if use_cpgd_loss:' block.
Args:
old_log_prob: Log-probabilities under the old policy.
log_prob: Log-probabilities under the current policy.
advantages: Advantage estimates.
response_mask: Mask for valid tokens.
loss_agg_mode: Aggregation mode for the loss.
config: Configuration object containing clip ratios.
rollout_is_weights: Not used in this specific CPGD implementation.
Returns:
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
"""
assert config is not None, "Config must be provided for CPGD loss"
# --- Extract clip parameters from config ---
clip_ratio = config.clip_ratio
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = (
config.clip_ratio_c if config.clip_ratio_c is not None else 3.0
) # Needed only for pg_clipfrac_lower metric
assert clip_ratio_c > 1.0, f"clip_ratio_c ({clip_ratio_c}) must be > 1.0"
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
clipped_log_prob = torch.where(
advantages > 0,
torch.clamp(log_prob, max=math.log(1 + clip_ratio_high) + old_log_prob),
torch.clamp(log_prob, min=math.log(1 - clip_ratio_low) + old_log_prob),
)
pg_losses = -clipped_log_prob * advantages
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
# Calculate clip fraction based on where the *ratio* would have been clipped
is_clipped = torch.where(advantages > 0, ratio > 1 + clip_ratio_high, ratio < 1 - clip_ratio_low)
pg_clipfrac = siirl_F.masked_mean(is_clipped.float(), response_mask).detach()
# Calculate lower clip fraction (dual clip metric)
pg_clipfrac_lower = siirl_F.masked_mean((ratio > clip_ratio_c) * (advantages < 0).float(), response_mask).detach()
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
def compute_policy_loss(
old_log_prob,
log_prob,
advantages,
response_mask,
cliprange=None,
cliprange_low=None,
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode: str = "token-mean",
use_cpgd_loss=False,
):
"""
Compute the clipped policy objective and related metrics for PPO.
Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
cliprange (float, optional):
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
Defaults to None (must be provided).
cliprange_low (float, optional):
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
cliprange_high (float, optional):
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
clip_ratio_c (float, optional):
Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
Defaults to 3.0.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
use_cpgd_loss (bool):
whether to use the CPGD loss
"""
assert clip_ratio_c > 1.0, "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but got the value: {clip_ratio_c}."
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
if use_cpgd_loss:
clipped_log_prob = torch.where(advantages > 0, torch.clamp(log_prob, max=math.log(1 + cliprange_high) + old_log_prob), torch.clamp(log_prob, min=math.log(1 - cliprange_low) + old_log_prob))
pg_losses = -clipped_log_prob * advantages
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) # use token-mean
is_clipped = torch.where(advantages > 0, ratio > 1 + cliprange_high, ratio < 1 - cliprange_low)
pg_clipfrac = siirl_F.masked_mean(is_clipped.float(), response_mask).detach()
pg_clipfrac_lower = siirl_F.masked_mean((ratio > clip_ratio_c) * (advantages < 0).float(), response_mask).detach()
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = siirl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
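# Editor's toy example of the dual-clip branch: for a negative advantage and a ratio far
# above 1, the loss is bounded by -clip_ratio_c * A instead of growing without limit.
def _example_dual_clip():
    old_log_prob = torch.zeros(1, 1)
    log_prob = torch.full((1, 1), 2.0)  # ratio = e^2 ≈ 7.39
    advantages = torch.full((1, 1), -1.0)
    mask = torch.ones(1, 1)
    pg_loss, _, _, _ = compute_policy_loss(
        old_log_prob, log_prob, advantages, mask, cliprange=0.2, clip_ratio_c=3.0
    )
    # without dual clip the loss would be ≈ 7.39; with it, min(3.0, 7.39) = 3.0
    assert torch.isclose(pg_loss, torch.tensor(3.0))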
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config: `(ActorArguments)`:
config for the actor.
rollout_is_weights: `(torch.Tensor)`:
optional importance-sampling weights derived from the rollout policy, shape (batch_size, response_length).
"""
assert config is not None
assert not isinstance(config, AlgorithmArguments)
clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = config.clip_ratio_c
cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high
assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
)
negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
pg_losses2 = -advantages * torch.clamp(
ratio, 1 - cliprange_low, 1 + cliprange_high
) # - clip(ratio, 1-cliprange, 1+cliprange) * A
clip_pg_losses1 = torch.maximum(
pg_losses1, pg_losses2
) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = siirl_F.masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gspo")
def compute_policy_loss_gspo(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "seq-mean-token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GSPO.
See https://arxiv.org/pdf/2507.18071 for more details.
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean".
"""
assert config is not None
assert isinstance(config, ActorArguments)
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio
negative_approx_kl = log_prob - old_log_prob
# compute sequence-level importance ratio:
# si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =
# exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths
# combined token-level ratio, with the sequence-level ratio treated as a constant:
# s_i,t(θ) = sg[s_i(θ)] * π_θ(y_i,t|x,y_i,<t) / sg[π_θ(y_i,t|x,y_i,<t)]
log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)  # clamp for numerical stability
seq_importance_ratio = torch.exp(log_seq_importance_ratio)
pg_losses1 = -advantages * seq_importance_ratio
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gpg")
def compute_policy_loss_gpg(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from
https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
Args:
log_prob: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
return:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via GPG
"""
pg_losses = -log_prob * advantages
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
@register_policy_loss("clip_cov")
def compute_policy_loss_clip_cov(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
Adapted from
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
cliprange (float, optional):
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
Defaults to None (must be provided).
cliprange_low (float, optional):
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
cliprange_high (float, optional):
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
clip_cov_ratio (float, optional):
Ratio of tokens whose covariance is clipped. Defaults to 0.0002.
clip_cov_lb (float, optional):
Lower bound for clipping covariance. Defaults to 1.0.
clip_cov_ub (float, optional):
Upper bound for clipping covariance. Defaults to 5.0.
"""
assert config is not None
assert not isinstance(config, ActorArguments), "passing AlgoConfig not supported yet"
assert config.policy_loss is not None
clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
cliprange = config.clip_ratio
cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0
assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
corr = torch.ones_like(advantages)
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)
cov_all = (advantages - siirl_F.masked_mean(advantages, response_mask)) * (
log_prob - siirl_F.masked_mean(log_prob.detach(), response_mask)
)
cov_all[response_mask == 0] = -torch.inf
cov_all[clip_by_origin] = -torch.inf
clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
top_k_idx = torch.nonzero(top_k_idx)
if len(top_k_idx) > 0:
perm = torch.randperm(len(top_k_idx))
top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
else:
top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)
corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0
pg_clipfrac = siirl_F.masked_mean((corr == 0).float(), response_mask)
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for KL-Cov.
Adapted from
https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
kl_cov_ratio (float, optional):
Ratio for selecting the top-k covariance values. Defaults to 0.0002.
ppo_kl_coef (float, optional):
Coefficient for the KL penalty term in the loss. Defaults to 1.
"""
assert config is not None
assert not isinstance(config, AlgorithmArguments), "passing AlgorithmArguments is not supported yet"
assert config.policy_loss is not None
kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."
negative_approx_kl = log_prob - old_log_prob
abs_kl = negative_approx_kl.abs()
ratio = torch.exp(negative_approx_kl)
ppo_kl_abs = siirl_F.masked_mean(negative_approx_kl.abs(), response_mask)
pg_losses1 = -advantages * ratio
pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
pg_losses = pg_losses1
all_valid = response_mask > 0
all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()
k = min(kl_cov_ratio, len(all_valid_adv))
if k != 0:
cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))
large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices
if len(large_cov_idxs) != 0:
large_cov_idxs = all_valid_idx[large_cov_idxs]
pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[
large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
]
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
@register_policy_loss("geo_mean")
def compute_policy_loss_geo_mean(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorArguments] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GMPO.
Adapted from paper https://arxiv.org/abs/2507.20673
https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
not used
"""
assert config is not None
assert not isinstance(config, AlgorithmArguments)
clip_ratio = config.clip_ratio # Clipping parameter. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for stability (uncomment it if you like)
# negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)
# Clipping at token-level & Clipping wider
sgn_advantage = torch.sign(advantages)
negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high)
negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp)
negative_approx_kl_min = sgn_advantage * negative_approx_kl_min
# Geometric-Mean Policy Optimization
response_mask_sum = response_mask.sum(dim=-1)
ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8))
# we only support sequence-level advantages for now;
# otherwise the computation below would not be consistent with the paper
advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
pg_losses = -advantage * ratio
# Apply rollout importance sampling weights if provided
# For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
if rollout_is_weights is not None:
# Aggregate token-level weights to sequence level using geometric mean for consistency
# Note: rollout_is_weights is always 2D regardless of rollout_is_level
seq_is_weights = torch.exp(
(torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
)
pg_losses = pg_losses * seq_is_weights
pg_loss = torch.mean(pg_losses)
# higher: the ratio is too large and needs clamping to clip_high (when adv > 0)
clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)
pg_clipfrac = siirl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)
pg_clipfrac_lower = siirl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
"""Compute categorical entropy loss (For backward compatibility)
Args:
logits (torch.Tensor): shape is (bs, response_length, vocab_size)
response_mask (torch.Tensor): shape is (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
token_entropy = siirl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return entropy_loss
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
cliprange_value: float,
loss_agg_mode: str = "token-mean",
):
"""
Compute the clipped value-function loss for PPO.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
vpreds (torch.FloatTensor):
Predicted values from the value head, shape (batch_size, response_length).
values (torch.FloatTensor):
Old (baseline) values from the value head, shape (batch_size, response_length).
returns (torch.FloatTensor):
Ground-truth returns, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the value loss calculation.
cliprange_value (float):
Clip range for value prediction updates.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
Returns:
vf_loss (torch.FloatTensor):
A scalar tensor containing the aggregated value-function loss.
vf_clipfrac (float):
Fraction of elements where the clipped loss was used.
"""
vpredclipped = siirl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
vf_clipfrac = siirl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other
kl penalty compute method for unbiased KL gradient estimation.
See more description in http://joschu.net/blog/kl-approx.html
Args:
logprob:
ref_logprob:
Returns:
kl_estimate
"""
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
return forward_score
"""
The expectation of the k1 and k3 estimators is the expected value of the KL, but their expected gradient is
not the expected gradient of the KL. The k2 estimator, on the other hand, gives the right gradient estimator,
so we use a straight-through trick here if the kl_penalty method ends with '+', e.g., k3+.
"""
backward_score = 0.5 * (logprob - ref_logprob).square()
return backward_score - backward_score.detach() + forward_score.detach()
def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
See more description in http://joschu.net/blog/kl-approx.html
Args:
logprob:
ref_logprob:
Returns:
kl_estimate
"""
if kl_penalty in ("kl", "k1"):
return logprob - ref_logprob
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if kl_penalty in ("mse", "k2"):
return 0.5 * (logprob - ref_logprob).square()
# J. Schulman. Approximating KL divergence, 2020.
# URL http://joschu.net/blog/kl-approx.html.
if kl_penalty in ("low_var_kl", "k3"):
kl = ref_logprob - logprob
# For numerical stability
kl = torch.clamp(kl, min=-20, max=20)
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
raise NotImplementedError
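# Editor's toy comparison of the estimators from http://joschu.net/blog/kl-approx.html
# on the same log-prob gap: k1 is signed, k2 is quadratic, k3 is the low-variance,
# always non-negative variant.
def _example_kl_estimators():
    logprob = torch.tensor([0.0])
    ref_logprob = torch.tensor([-0.5])
    k1 = kl_penalty_forward(logprob, ref_logprob, "k1")  # logprob - ref = 0.5
    k2 = kl_penalty_forward(logprob, ref_logprob, "k2")  # 0.5 * 0.5**2 = 0.125
    k3 = kl_penalty_forward(logprob, ref_logprob, "k3")  # e^{-0.5} + 0.5 - 1 ≈ 0.107
    assert k1.item() == 0.5 and k2.item() == 0.125 and k3.item() >= 0.0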
def compute_pf_ppo_reweight_data(
data,
reweight_method: str = "pow",
weight_pow: float = 2.0,
):
"""Reweight the data based on the token_level_scores.
Args:
data: TensorDict object, containing batch, non_tensor_batch and meta_info
reweight_method: str, choices: "pow", "max_min", "max_random"
weight_pow: float, the power of the weight
Returns:
The resampled data object with the same batch size.
"""
@torch.no_grad()
def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:
"""Compute importance weights for resampling based on scores.
Args:
scores (torch.Tensor): Tensor of scores to compute weights from.
reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').
weight_pow (float): Power exponent for 'pow' method.
Returns:
torch.Tensor: Computed importance weights.
Raises:
ValueError: If reweight_method is not supported.
"""
if reweight_method == "pow":
weights = torch.pow(torch.abs(scores), weight_pow)
elif reweight_method == "max_min":
max_score = torch.max(scores)
min_score = torch.min(scores)
weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)
elif reweight_method == "max_random":
max_score = torch.max(scores)
weights = torch.where(scores == max_score, 0.4, 0.1)
else:
raise ValueError(f"Unsupported reweight_method: {reweight_method}")
return weights
scores = data.batch["token_level_scores"].sum(dim=-1)
weights = compute_weights(scores, reweight_method, weight_pow)
weights = torch.clamp(weights + 1e-8, min=1e-8)
batch_size = scores.shape[0]
sample_indices = torch.multinomial(weights, batch_size, replacement=True)
resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}
sample_indices_np = sample_indices.numpy()
resampled_non_tensor_batch = {}
for key, array in data.non_tensor_batch.items():
if isinstance(array, np.ndarray):
resampled_non_tensor_batch[key] = array[sample_indices_np]
else:
resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]
resampled_meta_info = {}
for key, value in data.meta_info.items():
if isinstance(value, list) and len(value) == batch_size:
resampled_meta_info[key] = [value[i] for i in sample_indices_np]
else:
resampled_meta_info[key] = value
from copy import deepcopy
resampled_data = deepcopy(data)
resampled_data.batch = type(data.batch)(resampled_batch)
resampled_data.batch.batch_size = data.batch.batch_size
resampled_data.non_tensor_batch = resampled_non_tensor_batch
resampled_data.meta_info = resampled_meta_info
return resampled_data
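# Editor's sketch of the resampling idea in isolation: scores become sampling weights
# (the "pow" mode is shown) and indices are drawn with replacement via torch.multinomial,
# so samples with large |score| appear more often in the resampled batch.
def _example_pf_ppo_weights():
    scores = torch.tensor([0.1, 2.0, -1.5, 0.0])
    weights = torch.pow(torch.abs(scores), 2.0)  # "pow" with weight_pow=2.0
    weights = torch.clamp(weights + 1e-8, min=1e-8)
    sample_indices = torch.multinomial(weights, num_samples=4, replacement=True)
    assert sample_indices.shape == (4,)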
def apply_kl_penalty(data: TensorDict, kl_ctrl: AdaptiveKLController, kl_penalty_type="kl", multi_turn=False):
"""Apply KL penalty to the token-level rewards.
This function computes the KL divergence between the reference policy and current policy,
then applies a penalty to the token-level rewards based on this divergence.
Args:
data (TensorDict): The data containing batched model outputs and inputs.
kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.
kl_penalty_type (str, optional): Type of KL penalty to apply. Defaults to "kl".
multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False.
Returns:
tuple: A tuple containing:
- The updated data with token-level rewards adjusted by KL penalty
- A dictionary of metrics related to the KL penalty
"""
responses = data["responses"]
token_level_scores = data["token_level_scores"]
batch_size = data.batch_size[0]
response_mask = data["response_mask"]
# compute kl between ref_policy and current policy
# When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
kld = kl_penalty(data["old_log_probs"], data["ref_log_prob"], kl_penalty=kl_penalty) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
token_level_rewards = token_level_scores - beta * kld
current_kl = siirl_F.masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards
metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}
return data, metrics
def compute_advantage(data: TensorDict, adv_estimator, gamma=1.0, lam=1.0, norm_adv_by_std_in_grpo=True, weight_factor_in_cpgd="STD_weight", **kwargs):
"""Compute advantage estimates for policy optimization.
This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, CPGD, etc.
The advantage estimates are used to guide policy optimization in RL algorithms.
Args:
data (TensorDict): The data containing batched model outputs and inputs.
adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++, CPGD).
gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.
lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.
**kwargs: Additional estimator-specific options, e.g. use_pf_ppo, pf_ppo_reweight_method and
pf_ppo_weight_pow (GAE), or agent_group_ids (GAE_MARFT).
norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True.
weight_factor_in_cpgd (str, optional): whether to use the STD weight as GRPO or clip_filter_like_weight. choices: {STD_weight, clip_filter_like_weight, naive}
Returns:
TensorDict: The updated data with computed advantages and returns.
"""
# Back-compatible with trainers that do not compute response mask in fit
if "response_mask" not in data.keys():
data.batch["response_mask"] = compute_response_mask(data)
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == AdvantageEstimator.GAE:
advantages, returns = compute_gae_advantage_return(
token_level_rewards=data["token_level_rewards"],
values=data["values"],
response_mask=data["response_mask"],
gamma=gamma,
lam=lam,
)
data["advantages"] = advantages
data["returns"] = returns
if kwargs.get("use_pf_ppo", False):
data = compute_pf_ppo_reweight_data(
data,
kwargs.get("pf_ppo_reweight_method", "pow"),
kwargs.get("pf_ppo_weight_pow", 2.0),
)
elif adv_estimator == AdvantageEstimator.GRPO:
if "finish_step" in data and data["responses"].ndim == 3:
# Embodied scenario: compute mask based on finish_step
responses = data["responses"]
batch_size = responses.size(0)
response_length = responses.size(1) * responses.size(2) # traj_len * action_token_len
# Get action_token_len from config or infer from responses shape
action_token_len = responses.size(2) # action token length
finish_step = data['finish_step'] * action_token_len
steps = torch.arange(response_length, device=responses.device)
steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)
grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1) # (batch_size, response_length)
logger.info(f"[GRPO] Using finish_step-based mask for embodied scenario")
else:
# NLP scenario or no finish_step: use attention_mask-based response_mask
grpo_calculation_mask = data["response_mask"]
logger.info(f"[GRPO] Using attention_mask-based response_mask for NLP scenario")
# Call compute_grpo_outcome_advantage with parameters matching its definition
advantages, returns = compute_grpo_outcome_advantage(
token_level_rewards=data["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data["uid"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
data["advantages"] = advantages
data["returns"] = returns
# Store the mask for consistent metrics calculation
data["response_mask"] = grpo_calculation_mask
logger.debug(f"[GRPO] Stored response_mask in batch for consistent metrics")
elif adv_estimator == AdvantageEstimator.CPGD:
cpgd_calculation_mask = data["response_mask"]
# Call compute_cpgd_outcome_advantage with parameters matching its definition
advantages, returns = compute_cpgd_outcome_advantage(
token_level_rewards=data["token_level_rewards"],
response_mask=cpgd_calculation_mask,
index=data["uid"],
weight_factor_in_cpgd=weight_factor_in_cpgd,
)
data["advantages"] = advantages
data["returns"] = returns
elif adv_estimator == AdvantageEstimator.GAE_MARFT:
compute_marft_gae_advantage_return(
data,
pre_agent_group_ids=kwargs["agent_group_ids"],
gamma=gamma,
lam=lam,
)
else:
raise NotImplementedError
return data
================================================
FILE: siirl/dag_worker/dag_utils.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for DAG worker operations.
"""
import os
import ray
import torch
import inspect
import json
import time
import csv
import hashlib
import numpy as np
import torch.distributed as dist
from contextlib import contextmanager
from datetime import datetime
from zoneinfo import ZoneInfo
from pathlib import Path
from collections import deque
from tensordict import TensorDict
from typing import Dict, Optional, Type, List, Any, Tuple, Union
from loguru import logger
from siirl.execution.dag.node import Node, NodeType, NodeRole
from siirl.execution.dag import TaskGraph
from siirl.utils.extras.device import get_device_name, device_synchronize
from siirl.engine.base_worker import Worker
from siirl.utils.import_string import import_string
from siirl.dag_worker.constants import DAGConstants
from siirl.dag_worker.data_structures import ValidationResult
from siirl.dag_worker.metric_aggregator import (
DistributedMetricAggregator,
_ReduceOp
)
# ==========================================================================================
# Section 1: Performance & Timing
# ==========================================================================================
@contextmanager
def timer(enable_perf: bool, name: str, timing_dict: dict):
"""Measures execution time of a code block and stores in timing_dict."""
if enable_perf:
device_synchronize()
start_time = time.perf_counter()
yield
if enable_perf:
device_synchronize()
end_time = time.perf_counter()
timing_dict[name] = timing_dict.get(name, 0) + end_time - start_time
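# Editor's usage sketch: repeated timings accumulate into a plain dict under one key.
# This assumes device_synchronize() is safe to call on the current platform (a barrier
# that is effectively a no-op when no accelerator work is pending).
def _example_timer():
    timings = {}
    with timer(True, "rollout", timings):
        _ = sum(range(1000))
    with timer(True, "rollout", timings):
        _ = sum(range(1000))
    assert timings["rollout"] > 0  # two measurements accumulated under "rollout"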
def add_prefix_to_dataproto(tensordict: TensorDict, node: Node):
"""
Adds a prefix to all keys in the TensorDict.
The prefix is formatted as f"agent_group_{node.agent_group}_".
Only keys that do not already have a prefix will be modified.
Args:
tensordict (TensorDict): The TensorDict instance.
node (Node): The node providing the agent_group.
"""
prefix = f"agent_group_{node.agent_group}_"
prefix_agent_group = "agent_group_"
# Process tensor batch
if tensordict is not None:
new_batch = {}
for key, value in tensordict.items():
if not key.startswith(prefix_agent_group):
new_key = prefix + key
new_batch[new_key] = value
else:
new_batch[key] = value
tensordict = TensorDict(new_batch, batch_size=tensordict.batch_size)
return tensordict
def remove_prefix_from_dataproto(tensordict, node: Node):
"""
Removes the prefix from all keys in the TensorDict.
Only keys with a matching prefix will have the prefix removed.
Args:
tensordict (TensorDict): The TensorDict instance.
node (Node): The node whose agent_group identifies the prefix.
"""
prefix = f"agent_group_{node.agent_group}_"
prefix_len = len(prefix)
# Process tensor batch
if tensordict is not None:
new_batch = {}
for key, value in tensordict.items():
if key.startswith(prefix):
new_key = key[prefix_len:]
new_batch[new_key] = value
else:
new_batch[key] = value
tensordict = TensorDict(new_batch, batch_size=tensordict.batch_size)
return tensordict
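# Editor's round-trip sketch: a SimpleNamespace with an `agent_group` attribute stands
# in for a real Node, since these helpers only read that attribute.
def _example_prefix_roundtrip():
    from types import SimpleNamespace

    node = SimpleNamespace(agent_group=0)
    td = TensorDict({"responses": torch.zeros(2, 3)}, batch_size=[2])
    prefixed = add_prefix_to_dataproto(td, node)
    assert "agent_group_0_responses" in prefixed.keys()
    restored = remove_prefix_from_dataproto(prefixed, node)
    assert "responses" in restored.keys()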
def add_prefix_to_metrics(metrics: dict, node: Node) -> dict:
"""Adds agent prefix to all metric keys for multi-agent isolation."""
prefix = f"agent_{node.agent_group}_"
prefix_agent_group = "agent_"
if metrics:
new_metrics = {}
for key, value in metrics.items():
if not key.startswith(prefix_agent_group):
new_key = prefix + key
new_metrics[new_key] = value
else:
new_metrics[key] = value
metrics = new_metrics
return metrics
# ==========================================================================================
# Section 3: Initialization & Setup
# ==========================================================================================
def get_and_validate_rank() -> int:
"""Retrieves and validates worker rank from RANK environment variable."""
rank_str = os.environ.get("RANK")
if rank_str is None:
raise ValueError("Environment variable 'RANK' is not set. This is required for distributed setup.")
try:
return int(rank_str)
except ValueError as e:
raise ValueError(f"Invalid RANK format: '{rank_str}'. Must be an integer.") from e
def get_taskgraph_for_rank(rank: int, taskgraph_mapping: Dict[int, TaskGraph]) -> TaskGraph:
"""Retrieves TaskGraph for current rank from mapping."""
if rank not in taskgraph_mapping:
raise ValueError(f"Rank {rank} not found in the provided taskgraph_mapping.")
taskgraph = taskgraph_mapping[rank]
if not isinstance(taskgraph, TaskGraph):
raise TypeError(f"Object for rank {rank} must be a TaskGraph, but got {type(taskgraph).__name__}.")
logger.info(f"Rank {rank} assigned to TaskGraph with ID {taskgraph.graph_id}.")
return taskgraph
def log_ray_actor_info(rank: int):
"""Logs Ray actor context information for debugging."""
try:
ctx = ray.get_runtime_context()
logger.debug(
f"Ray Actor Context for Rank {rank}: ActorID={ctx.get_actor_id()}, JobID={ctx.get_job_id()}, "
f"NodeID={ctx.get_node_id()}, PID={os.getpid()}"
)
except RuntimeError:
logger.warning(f"Rank {rank}: Not running in a Ray actor context.")
def log_role_worker_mapping(role_worker_mapping: Dict[NodeRole, Type[Worker]]):
"""Logs role-to-worker class mapping for verification."""
if not role_worker_mapping:
logger.error("Role-to-worker mapping is empty after setup. This will cause execution failure.")
return
logger.debug("--- [Role -> Worker Class] Mapping ---")
max_len = max((len(r.name) for r in role_worker_mapping.keys()), default=0)
for role, worker_cls in sorted(role_worker_mapping.items(), key=lambda item: item[0].name):
logger.debug(
f" {role.name:<{max_len}} => {worker_cls.__name__} (from {inspect.getmodule(worker_cls).__name__})"
)
logger.debug("--------------------------------------")
# ==========================================================================================
# Section 4: Worker Management
# ==========================================================================================
def find_first_non_compute_ancestor(taskgraph: TaskGraph, start_node_id: str) -> Optional[Node]:
"""Finds first ancestor node that is not COMPUTE type using BFS."""
start_node = taskgraph.get_node(start_node_id)
if not start_node:
logger.warning(f"Could not find start node '{start_node_id}' in the graph.")
return None
if start_node.node_type != NodeType.COMPUTE:
return start_node
queue = deque(start_node.dependencies)
visited = set(start_node.dependencies)
    while queue:
        node_id = queue.popleft()
        logger.debug(f"Upward search: visiting dependency node '{node_id}'.")
node = taskgraph.get_node(node_id)
if not node:
logger.warning(f"Could not find dependency node with ID '{node_id}' during upward search.")
continue
if node.node_type != NodeType.COMPUTE:
return node
for dep_id in node.dependencies:
if dep_id not in visited:
visited.add(dep_id)
queue.append(dep_id)
return None
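# Worked example: given a dependency chain where an "advantage" COMPUTE node
# depends on a "reward" COMPUTE node, which in turn depends on a "rollout"
# MODEL_INFERENCE node, find_first_non_compute_ancestor(graph, "advantage") walks
# the COMPUTE ancestors breadth-first and returns the rollout node, the first
# non-COMPUTE node encountered.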
def should_create_worker(role_worker_mapping: Dict[NodeRole, Type[Worker]], node: Node) -> bool:
"""Determines if worker instance should be created for a given node."""
if node.agent_options and node.agent_options.share_instance:
# Worker already initialized in target agent node
return False
return node.node_type in [NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE] and node.node_role in role_worker_mapping
def generate_node_worker_key(node: Node) -> str:
"""Generates unique key for node's worker instance."""
return f"{node.agent_group}_{node.node_type.value}_{node.node_role.value}"
def setup_sharding_manager(
config,
agent_group_process_group: Dict,
agent_group: int,
worker_dict: Dict[NodeRole, Worker]
):
"""Configures sharding manager to sync weights between training and inference backends."""
actor_worker = worker_dict[NodeRole.ACTOR]
rollout_worker = worker_dict[NodeRole.ROLLOUT]
rollout_pg = agent_group_process_group[agent_group][NodeRole.ROLLOUT]
if config.actor_rollout_ref.model.model_type == "embodied":
if hasattr(actor_worker, "actor_module_fsdp"):
rollout_worker.rollout.model = actor_worker.actor_module_fsdp
logger.info(f"[Embodied] Set module for EmbodiedHFRollout for agent group {agent_group}.")
else:
logger.error(f"[Embodied] Actor worker for agent group {agent_group} does not have 'actor_module_fsdp'.")
parallel_config = {
"rollout_parallel_size": rollout_worker.config.rollout.tensor_model_parallel_size,
"rollout_world_size": dist.get_world_size(rollout_pg),
"rollout_rank": dist.get_rank(rollout_pg),
}
device_name = get_device_name()
layer_name_mapping = {
"qkv_layer_name": "self_attention.linear_qkv.",
"gate_proj_layer_name": "linear_fc1.weight",
}
# Lazy import and deferred execution mapping
sharding_manager_map = {
("fsdp", "hf"): (
"siirl.engine.sharding_manager.fsdp_hf.FSDPHFShardingManager",
lambda: {
"module": actor_worker.actor_module_fsdp,
"rollout": rollout_worker.rollout,
"offload_param": getattr(actor_worker, "_is_offload_param", False),
"offload_embedding": (
getattr(rollout_worker.config, "embodied", None) is not None and
getattr(rollout_worker.config.embodied, "embedding_model_offload", False)),
},
),
("fsdp", "vllm"): (
"siirl.engine.sharding_manager.fsdp_vllm.MultiAgentFSDPVLLMShardingManager",
lambda: {
"module": actor_worker.actor_module_fsdp,
"inference_engine": rollout_worker.rollout.inference_engine,
"model_config": actor_worker.actor_model_config,
"parallel_config": parallel_config,
"full_params": "hf" in rollout_worker.config.rollout.load_format,
"offload_param": getattr(actor_worker, "_is_offload_param", False),
},
),
("fsdp", "sglang"): (
"siirl.engine.sharding_manager.fsdp_sglang.MultiAgentFSDPSGLangShardingManager",
lambda: {
"module": actor_worker.actor_module_fsdp,
"inference_engine": rollout_worker.rollout.inference_engine,
"model_config": actor_worker.actor_model_config,
"device_mesh": torch.distributed.init_device_mesh(
device_name,
mesh_shape=(
parallel_config.get("rollout_world_size") // parallel_config.get("rollout_parallel_size"),
parallel_config.get("rollout_parallel_size"),
),
mesh_dim_names=["dp", "infer_tp"],
),
"rollout_config": rollout_worker.config.rollout,
"full_params": "hf" in rollout_worker.config.rollout.load_format,
"offload_param": getattr(actor_worker, "_is_offload_param", False),
"multi_stage_wake_up": rollout_worker.config.rollout.multi_stage_wake_up,
},
),
("megatron", "vllm"): (
"siirl.engine.sharding_manager.megatron_vllm.MultiAgentMegatronVLLMShardingManager",
lambda: {
"actor_module": actor_worker.actor_module,
"inference_engine": rollout_worker.rollout.inference_engine,
"model_config": actor_worker.actor_model_config,
"rollout_config": rollout_worker.config.rollout,
"transformer_config": actor_worker.tf_config,
"layer_name_mapping": layer_name_mapping,
"weight_converter": get_mcore_weight_converter(actor_worker.actor_model_config, actor_worker.dtype),
"device_mesh": rollout_worker.device_mesh,
"offload_param": actor_worker._is_offload_param,
"bridge": actor_worker.bridge,
},
),
("megatron", "sglang"): (
"siirl.engine.sharding_manager.megatron_sglang.MultiAgentMegatronSGLangShardingManager",
lambda: {
"actor_module": actor_worker.actor_module,
"inference_engine": rollout_worker.rollout.inference_engine,
"model_config": actor_worker.actor_model_config,
"rollout_config": rollout_worker.config.rollout,
"transformer_config": actor_worker.tf_config,
"layer_name_mapping": layer_name_mapping,
"weight_converter": get_mcore_weight_converter(actor_worker.actor_model_config, actor_worker.dtype),
"device_mesh": torch.distributed.init_device_mesh(
device_name,
mesh_shape=(
parallel_config.get("rollout_world_size") // parallel_config.get("rollout_parallel_size"),
parallel_config.get("rollout_parallel_size"),
),
mesh_dim_names=["dp", "infer_tp"],
),
"offload_param": getattr(actor_worker, "_is_offload_param", False),
"bridge": actor_worker.bridge,
},
),
}
strategy = actor_worker.config.actor.strategy.lower()
if strategy == DAGConstants.MEGATRON_STRATEGY:
from siirl.models.mcore import get_mcore_weight_converter
rollout_name = config.actor_rollout_ref.rollout.name.lower()
if (strategy, rollout_name) not in sharding_manager_map:
raise NotImplementedError(f"Unsupported sharding manager configuration: {strategy=}, {rollout_name=}")
sharding_manager_cls_str, kwargs_builder = sharding_manager_map[(strategy, rollout_name)]
sharding_manager_cls = import_string(sharding_manager_cls_str)
sharding_manager = sharding_manager_cls(**kwargs_builder())
rollout_worker.set_rollout_sharding_manager(sharding_manager)
logger.debug(f"Set up {sharding_manager_cls.__name__} for agent group {agent_group}.")
def get_worker_classes(config, strategy: str) -> Dict[NodeRole, Type[Worker]]:
"""Dynamically imports worker classes based on specified training strategy."""
if strategy in DAGConstants.FSDP_STRATEGIES:
from siirl.engine.fsdp_workers import (
ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
CriticWorker,
RewardModelWorker,
)
actor_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
return {
NodeRole.ACTOR: actor_cls,
NodeRole.ROLLOUT: actor_cls,
NodeRole.REFERENCE: actor_cls,
NodeRole.CRITIC: CriticWorker,
NodeRole.REWARD: RewardModelWorker,
}
elif strategy in DAGConstants.MEGATRON_STRATEGYS:
from siirl.engine.megatron_workers import (
ActorWorker,
RolloutWorker,
AsyncRolloutWorker,
ReferenceWorker,
CriticWorker,
RewardModelWorker
)
is_async_mode = config.actor_rollout_ref.rollout.mode == "async"
return {
NodeRole.ACTOR: ActorWorker,
NodeRole.ROLLOUT: AsyncRolloutWorker if is_async_mode else RolloutWorker,
NodeRole.REFERENCE: ReferenceWorker,
NodeRole.CRITIC: CriticWorker,
NodeRole.REWARD: RewardModelWorker
}
raise NotImplementedError(f"Strategy '{strategy}' is not supported.")
def get_parallelism_config(reference_node: Node) -> tuple[int, int]:
"""Extracts tensor parallel (TP) and pipeline parallel (PP) sizes from node config."""
tp_size = 1
pp_size = 1
if intern_config := reference_node.config.get(DAGConstants.INTERN_CONFIG):
if reference_node.node_type == NodeType.MODEL_INFERENCE:
# Rollout nodes: only TP supported (PP not typically used for inference)
tp_size = intern_config.rollout.tensor_model_parallel_size
pp_size = 1
elif reference_node.node_type == NodeType.MODEL_TRAIN:
# Extract strategy from config
strategy = 'fsdp' # default
if hasattr(intern_config, 'actor') and hasattr(intern_config.actor, 'strategy'):
strategy = intern_config.actor.strategy
elif hasattr(intern_config, 'strategy'):
strategy = intern_config.strategy
if strategy in DAGConstants.MEGATRON_STRATEGYS:
# Megatron supports both TP and PP
if hasattr(intern_config, 'actor') and hasattr(intern_config.actor, 'megatron'):
tp_size = intern_config.actor.megatron.tensor_model_parallel_size
pp_size = intern_config.actor.megatron.pipeline_model_parallel_size
elif hasattr(intern_config, 'megatron'):
tp_size = intern_config.megatron.tensor_model_parallel_size
pp_size = intern_config.megatron.pipeline_model_parallel_size
else:
# FSDP: no TP/PP, keep TP=PP=1
tp_size = 1
pp_size = 1
return tp_size, pp_size
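# Minimal sketch of the config shape this reads, with SimpleNamespace standing in
# for the real parsed config (hypothetical values): a Megatron MODEL_TRAIN node
# with TP=4 and PP=2 resolves to (4, 2), while FSDP nodes fall back to (1, 1).
#
#     from types import SimpleNamespace
#     intern_config = SimpleNamespace(actor=SimpleNamespace(
#         strategy="megatron",
#         megatron=SimpleNamespace(tensor_model_parallel_size=4,
#                                  pipeline_model_parallel_size=2)))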
def prepare_generation_batch(batch: TensorDict) -> TensorDict:
    """Pops the keys needed for sequence generation out of a batch.
    The selected keys are removed from the caller's batch and returned as a
    separate TensorDict for the rollout engine.
    """
    keys_to_pop = ["input_ids", "attention_mask", "position_ids", "raw_prompt_ids"]
    if "multi_modal_inputs" in batch:
        keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
    if "tools_kwargs" in batch:
        keys_to_pop.append("tools_kwargs")
    if "raw_prompt" in batch:
        keys_to_pop.append("raw_prompt")
    if "interaction_kwargs" in batch:
        keys_to_pop.append("interaction_kwargs")
    # Split the generation keys into a new TensorDict and drop them from the caller's batch.
    present_keys = [key for key in keys_to_pop if key in batch]
    gen_batch = batch.select(*present_keys)
    for key in present_keys:
        del batch[key]
    return gen_batch
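# Example split: a batch holding {"input_ids", "attention_mask", "position_ids",
# "raw_prompt_ids", "advantages"} returns a generation TensorDict with the first
# four keys, while "advantages" stays behind in the caller's batch.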
def prepare_local_batch_metrics(batch: TensorDict, use_critic: bool = True) -> Dict[str, torch.Tensor]:
"""Extracts raw metric tensors from batch for distributed aggregation."""
from siirl.utils.metrics.metric_utils import _compute_response_info
response_info = _compute_response_info(batch)
response_mask = response_info["response_mask"].bool()
device = batch["advantages"].device
max_response_length = batch["responses"].shape[-1]
response_lengths = response_info["response_length"].to(device)
prompt_lengths = response_info["prompt_length"].to(device)
# Components for correct/wrong response length metrics
correct_threshold = 0.5
rewards_per_response = batch["token_level_rewards"].sum(-1)
correct_mask = rewards_per_response > correct_threshold
# Components for prompt clip ratio
prompt_attn_mask = batch["attention_mask"][:, :-max_response_length]
max_prompt_length = prompt_attn_mask.size(-1)
# Prepare raw metric values
local_data = {
"score": batch["token_level_scores"].sum(-1),
"rewards": batch["token_level_rewards"].sum(-1),
"advantages": torch.masked_select(batch["advantages"], response_mask),
"returns": torch.masked_select(batch["returns"], response_mask),
"response_length": response_info["response_length"].to(device),
"prompt_length": response_info["prompt_length"].to(device),
"correct_response_length": response_lengths[correct_mask],
"wrong_response_length": response_lengths[~correct_mask],
"response_clip_ratio": torch.eq(response_info["response_length"], max_response_length).float(),
"prompt_clip_ratio": torch.eq(prompt_lengths, max_prompt_length).float(),
}
if use_critic:
valid_values = torch.masked_select(batch["values"], response_mask)
error = local_data["returns"] - valid_values
critic_data = {
"values": valid_values,
# Special components for explained variance (summed globally)
"returns_sq_sum_comp": torch.sum(torch.square(local_data["returns"])),
"error_sum_comp": torch.sum(error),
"error_sq_sum_comp": torch.sum(torch.square(error)),
}
local_data.update(critic_data)
return local_data
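# The *_comp sums let explained variance be reconstructed globally without
# gathering raw tensors. One standard reconstruction, given the globally reduced
# token count n, the summed returns, and error = returns - values:
#
#     var_err = error_sq_sum / n - (error_sum / n) ** 2
#     var_ret = returns_sq_sum / n - (returns_sum / n) ** 2
#     explained_variance = 1.0 - var_err / var_ret  # guard against var_ret == 0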
def whether_put_data(rank, is_current_last_pp_tp_rank0, next_dp_size, cur_dp_size, cur_node, next_node) -> bool:
    """Determines whether this rank should put data into the buffer for the next node."""
    result = False
    reason = "No condition met"
if is_current_last_pp_tp_rank0:
result = True
reason = "Current last PP rank's TP rank 0"
elif next_dp_size == cur_dp_size:
if next_node.node_type in [NodeType.COMPUTE, NodeType.MODEL_TRAIN]:
result = True
reason = f"DP sizes match and next node is {next_node.node_type}"
elif cur_node.node_role == next_node.node_role and cur_node.node_role == NodeRole.ROLLOUT:
result = True
reason = "Both nodes are ROLLOUT"
logger.debug(f"Rank {rank}: _whether_put_data decision for {cur_node.node_id}->{next_node.node_id}: {result} ({reason}). "
f"is_current_last_pp_tp_rank0={is_current_last_pp_tp_rank0}, next_dp_size={next_dp_size}, cur_dp_size={cur_dp_size}, "
f"cur_node_type={cur_node.node_type}, next_node_type={next_node.node_type}, "
f"cur_node_role={cur_node.node_role}, next_node_role={next_node.node_role}")
return result
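# Decision table (node arguments elided to their relevant fields):
#   last PP stage and TP rank 0                               -> True
#   equal DP sizes and next node is COMPUTE or MODEL_TRAIN    -> True
#   equal DP sizes and both nodes are ROLLOUT                 -> True
#   anything else                                             -> False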
# ==========================================================================================
# Section 6: Metrics Collection & Aggregation
# ==========================================================================================
def reduce_and_broadcast_metrics(
local_metrics: Dict[str, Union[float, List[float], torch.Tensor]],
group: dist.ProcessGroup
) -> Dict[str, float]:
"""Aggregates metrics across all ranks using all_reduce operations."""
if not isinstance(local_metrics, dict) or not local_metrics:
return {}
world_size = dist.get_world_size(group)
if world_size <= 1:
# Non-distributed case: perform local aggregation only
aggregator = DistributedMetricAggregator(local_metrics, group=None)
final_metrics = {}
for op_type, data in aggregator.op_buckets.items():
for key, value in data:
if op_type == _ReduceOp.SUM: # value is a (sum, count) tuple
final_metrics[key] = value[0] / value[1] if value[1] > 0 else 0.0
else: # value is a float
final_metrics[key] = float(value)
return final_metrics
# Pipeline Parallel: ensure all ranks have same metric keys
# 1. Gather all metric keys from all ranks
local_keys = set(local_metrics.keys())
all_keys_list = [None] * world_size
dist.all_gather_object(all_keys_list, local_keys, group=group)
# 2. Union all keys to get complete set
all_expected_keys = set()
for keys_set in all_keys_list:
all_expected_keys.update(keys_set)
# 3. Aggregate with unified keys
aggregator = DistributedMetricAggregator(local_metrics, group)
aggregator.op_buckets = aggregator._bucket_local_metrics(local_metrics, all_expected_keys)
return aggregator.aggregate_and_get_results()
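# Why the key-union handshake above matters: under pipeline parallelism the last
# PP stage may hold a metric such as "actor/entropy_loss" that earlier stages
# never produce. Agreeing on the union of keys first keeps every rank's collective
# calls symmetric, so all_reduce cannot hang on mismatched metric sets.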
def format_metrics_by_group(metrics: Dict[str, Any], group_order: List[str]) -> Dict[str, Any]:
"""Reorders metrics by group prefixes and alphabetically within groups."""
if not metrics:
return {}
ordered_dict = {}
processed_keys = set()
# Pre-identify explicitly mentioned full keys
explicitly_mentioned_keys = {key for key in group_order if key in metrics}
# Process metrics according to group/key order
for pattern in group_order:
# Check if pattern is a full key
if pattern in explicitly_mentioned_keys and pattern not in processed_keys:
ordered_dict[pattern] = metrics[pattern]
processed_keys.add(pattern)
else:
# Treat as group prefix
group_prefix = f"{pattern}/"
# Find all keys in this group and sort alphabetically
keys_in_group = sorted(
[
key
for key in metrics
if key.startswith(group_prefix)
and key not in processed_keys
and key not in explicitly_mentioned_keys
]
)
for key in keys_in_group:
ordered_dict[key] = metrics[key]
processed_keys.add(key)
# Process remaining keys
remaining_keys = sorted([key for key in metrics if key not in processed_keys])
if remaining_keys:
for key in remaining_keys:
ordered_dict[key] = metrics[key]
return ordered_dict
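# Pure-function example (no distributed state involved):
#
#     metrics = {"critic/loss": 2.0, "actor/loss": 1.0, "perf/time": 3.0, "misc": 4}
#     format_metrics_by_group(metrics, ["perf", "actor"])
#     # keys come back ordered: "perf/time", "actor/loss", then the remainder
#     # alphabetically: "critic/loss", "misc"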
# ==========================================================================================
# Section 7: Logging & Output
# ==========================================================================================
def log_metrics_to_console(rank: int, ordered_metrics: Union[Dict[str, Any], List[Tuple[str, Any]]], step: int):
    """Logs a formatted metrics string to the console (rank 0 only)."""
    if rank != 0:
        return
    items = ordered_metrics.items() if isinstance(ordered_metrics, dict) else ordered_metrics
    log_parts = [f"step:{step}"]
    log_parts.extend([f"{k}:{v:.4f}" if isinstance(v, float) else f"{k}:{v}" for k, v in items])
logger.info(" | ".join(log_parts))
def dump_validation_generations(
config,
global_steps: int,
rank: int,
results: List[ValidationResult]
):
"""Dumps validation generation results to rank-specific JSON file."""
dump_path_str = config.trainer.rollout_data_dir
if not dump_path_str:
return
dump_path = Path(dump_path_str)
try:
dump_path.mkdir(parents=True, exist_ok=True)
filename = dump_path / f"step_{global_steps}_rank_{rank}.json"
# Collect entries
entries = []
for res in results:
entry = {
"rank": rank,
"global_step": global_steps,
"data_source": res.data_source,
"input": res.input_text,
"output": res.output_text,
"score": res.score,
}
if res.extra_rewards:
entry.update(res.extra_rewards)
entries.append(entry)
# Write with pretty formatting
with open(filename, "w", encoding="utf-8") as f:
json.dump(entries, f, ensure_ascii=False, indent=4)
if rank == 0:
logger.info(f"Validation generations are being dumped by all ranks to: {dump_path.resolve()}")
logger.debug(f"Rank {rank}: Dumped {len(results)} validation generations to {filename}")
except (OSError, IOError) as e:
logger.error(f"Rank {rank}: Failed to write validation dump file to {dump_path}: {e}")
except Exception as e:
logger.error(f"Rank {rank}: An unexpected error occurred during validation dumping: {e}", exc_info=True)
def aggregate_and_write_performance_metrics(
gather_group,
rank,
global_steps,
config,
metrics: Dict[str, Any]):
"""
Gathers performance metrics from all ranks to rank 0 and writes them to a CSV file.
Each row corresponds to a metric key COMMON to all ranks, and each column to a rank.
This function is called only if performance profiling is enabled.
"""
# Gather all metrics dictionaries to rank 0
world_size = dist.get_world_size()
gathered_metrics = [None] * world_size if rank == 0 else None
dist.gather_object(metrics, gathered_metrics, dst=0, group=gather_group)
if rank == 0:
if not gathered_metrics:
logger.warning("No metrics gathered on rank 0. Skipping performance CSV write.")
return
valid_metrics = [m for m in gathered_metrics if isinstance(m, dict) and m]
if not valid_metrics:
logger.warning("No valid metric dictionaries received on rank 0. Skipping CSV write.")
return
common_keys = set(valid_metrics[0].keys())
for rank_metrics in valid_metrics[1:]:
common_keys.intersection_update(rank_metrics.keys())
sorted_keys = sorted(list(common_keys))
if not sorted_keys:
logger.warning(
f"No common metric keys found across all ranks for step {global_steps}. Skipping CSV write."
)
return
ts = get_time_now().strftime("%Y-%m-%d-%H-%M-%S")
try:
# Try to get model name from model path config
model_name = os.path.basename(os.path.normpath(config.actor_rollout_ref.model.path))
output_dir = os.path.join("performance_logs", model_name, ts)
os.makedirs(output_dir, exist_ok=True)
except OSError as e:
logger.error(f"Failed to create performance log directory {output_dir}: {e}")
return
filename = os.path.join(output_dir, f"world_{world_size}_step_{global_steps}_common_metrics.csv")
try:
with open(filename, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
header = (
["metric"]
+ [f"rank_{i}" for i in range(world_size)]
+ ["max", "min", "delta_max_min", "delta_max_rank_0"]
)
writer.writerow(header)
for key in sorted_keys:
row = [key]
for i in range(world_size):
rank_metrics = gathered_metrics[i]
if isinstance(rank_metrics, dict):
value = rank_metrics.get(key, "Error: Key Missing")
else:
value = "N/A: Invalid Data"
row.append(value)
row_max = max([x for x in row[1:] if isinstance(x, (int, float))], default="N/A")
row_min = min([x for x in row[1:] if isinstance(x, (int, float))], default="N/A")
row_delta_max = (
row_max - row_min
if isinstance(row_max, (int, float)) and isinstance(row_min, (int, float))
else "N/A"
)
row_delta_rank0 = row_max - row[1] if isinstance(row[1], (int, float)) else "N/A"
row.extend([row_max, row_min, row_delta_max, row_delta_rank0])
writer.writerow(row)
logger.info(
f"Common performance metrics for step {global_steps} successfully written to {filename}"
)
except OSError as e:
logger.error(f"Failed to write performance metrics to CSV file {filename}: {e}")
def log_core_performance_metrics(rank: int, enable_perf: bool, metrics: Dict[str, Any], step: int):
"""
Logs a formatted, easy-to-read summary of core performance metrics on rank 0.
This provides a clear, separate view of the most important indicators.
"""
if rank != 0:
return
def get_metric(key, precision=3):
val = metrics.get(key)
if val is None:
return "N/A"
if isinstance(val, (float, np.floating)):
return f"{val:.{precision}f}"
return val
# --- Build the log string ---
log_str = f"\n\n{'=' * 25} RANK({rank}): Core Performance Metrics (Step: {step}) {'=' * 25}\n"
# --- Overall Performance ---
log_str += "\n--- ⏱️ Overall Performance ---\n"
log_str += f" {'Step Time':<28}: {get_metric('perf/time_per_step', 3)} s\n"
log_str += f" {'Throughput (tokens/s)':<28}: {get_metric('perf/throughput', 2)}\n"
log_str += f" {'Total Tokens in Step':<28}: {get_metric('perf/total_num_tokens', 0)}\n"
# --- Algorithm-Specific Metrics ---
log_str += "\n--- 📈 Algorithm Metrics ---\n"
log_str += f" {'Actor Entropy':<28}: {get_metric('actor/entropy_loss', 4)}\n"
log_str += (
f" {'Critic Rewards (Mean/Min/Max)':<28}: {get_metric('critic/rewards/mean', 3)} / "
f"{get_metric('critic/rewards/min', 3)} / {get_metric('critic/rewards/max', 3)}\n"
)
log_str += (
f" {'Critic Scores (Mean/Min/Max)':<28}: {get_metric('critic/score/mean', 3)} / "
f"{get_metric('critic/score/min', 3)} / {get_metric('critic/score/max', 3)}\n"
)
if enable_perf:
# --- Module-wise Timings (Single Column) ---
log_str += "\n--- ⏳ Module-wise Timings (s) ---\n"
# Dynamically find all delta_time metrics except the total step time
timing_keys = sorted(
[k for k in metrics.keys() if k.startswith("perf/delta_time/") and k != "perf/delta_time/step"]
)
ref_key = "perf/delta_time/ref"
reference_key = "perf/delta_time/reference"
if ref_key in timing_keys and reference_key in timing_keys:
timing_keys.remove(reference_key)
        if timing_keys:
            # Find the maximum label length across all keys for clean alignment
            max_label_len = max(
                len(k.replace("perf/delta_time/", "").replace("_", " ").title()) for k in timing_keys
            )
for key in timing_keys:
label = key.replace("perf/delta_time/", "").replace("_", " ").title()
value = get_metric(key, 3)
log_str += f" {label:<{max_label_len}} : {value}s\n"
else:
log_str += " No detailed timing metrics available.\n"
# --- Model Flops Utilization (MFU) ---
log_str += "\n--- 🔥 Model Flops Utilization (MFU) ---\n"
log_str += f" {'Mean MFU':<28}: {get_metric('perf/mfu/mean', 3)}\n"
log_str += f" {'Actor Training MFU':<28}: {get_metric('perf/mfu/actor', 3)}\n"
# log_str += f" {'Rollout MFU':<28}: {get_metric('perf/mfu/rollout', 3)}\n"
log_str += f" {'Reference Policy MFU':<28}: {get_metric('perf/mfu/ref', 3)}\n"
log_str += f" {'Actor LogProb MFU':<28}: {get_metric('perf/mfu/actor_log_prob', 3)}\n"
# --- Memory Usage ---
log_str += "\n--- 💾 Memory Usage ---\n"
log_str += f" {'Max GPU Memory Allocated':<28}: {get_metric('perf/max_memory_allocated_gb', 2)} GB\n"
log_str += f" {'Max GPU Memory Reserved':<28}: {get_metric('perf/max_memory_reserved_gb', 2)} GB\n"
log_str += f" {'CPU Memory Used':<28}: {get_metric('perf/cpu_memory_used_gb', 2)} GB\n"
# --- Sequence Lengths ---
log_str += "\n--- 📏 Sequence Lengths ---\n"
log_str += (
f" {'Prompt Length (Mean/Max)':<28}: {get_metric('prompt/length/mean', 1)} / "
f"{get_metric('prompt/length/max', 0)}\n"
)
log_str += (
f" {'Response Length (Mean/Max)':<28}: {get_metric('response/length/mean', 1)} / "
f"{get_metric('response/length/max', 0)}\n"
)
log_str += f" {'Response Clip Ratio':<28}: {get_metric('response/clip_ratio/mean', 4)}\n"
log_str += f" {'Prompt Clip Ratio':<28}: {get_metric('prompt/clip_ratio/mean', 4)}\n"
log_str += (
f" {'Correct Resp Len (Mean/Max)':<28}: {get_metric('response/correct_length/mean', 1)} / "
f"{get_metric('response/correct_length/max', 0)}\n"
)
log_str += (
f" {'Wrong Resp Len (Mean/Max)':<28}: {get_metric('response/wrong_length/mean', 1)} / "
f"{get_metric('response/wrong_length/max', 0)}\n"
)
log_str += "\n" + "=" * 82 + "\n"
logger.info(log_str)
# ==========================================================================================
# Section 8: General Utilities
# ==========================================================================================
def get_time_now(time_zone: str = "Asia/Shanghai") -> datetime:
"""Returns current time in specified timezone."""
return datetime.now(tz=ZoneInfo(time_zone))
def consistent_hash(s: str) -> int:
"""Returns consistent hash of string using MD5."""
return int(hashlib.md5(s.encode()).hexdigest(), 16)
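# Unlike the builtin hash(), which is salted per process via PYTHONHASHSEED, this
# MD5-based value is identical across ranks and restarts:
#
#     consistent_hash("sample-42") == consistent_hash("sample-42")  # True on any rank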
================================================
FILE: siirl/dag_worker/dagworker.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uuid
import ray
import torch
import asyncio
import numpy as np
import torch.distributed as dist
from collections import defaultdict
from pprint import pformat
from tqdm import tqdm
from loguru import logger
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Callable
from torch.distributed import ProcessGroup
from tensordict import TensorDict
# Handle different tensordict versions - NonTensorData location varies
try:
from tensordict import NonTensorData
except ImportError:
from tensordict.tensorclass import NonTensorData
import time
from siirl.execution.metric_worker.metric_worker import MetricClient
from siirl.models.loader import TokenizerModule, load_tokenizer
from siirl.params import SiiRLArguments
from siirl.engine.base_worker import Worker
from siirl.execution.dag import TaskGraph
from siirl.execution.dag.node import NodeRole, NodeType, Node
from siirl.execution.scheduler.reward import compute_reward, create_reward_manager
from siirl.execution.scheduler.process_group_manager import ProcessGroupManager
from siirl.execution.scheduler.enums import AdvantageEstimator, WorkflowType
from siirl.data_coordinator import preprocess_dataloader, Samples2Dict, Dict2Samples, SampleInfo
from siirl.data_coordinator.dataloader import DataLoaderNode
from siirl.dag_worker.data_structures import NodeOutput
from siirl.dag_worker.constants import DAGConstants, DAGInitializationError
from siirl.dag_worker import core_algos
from siirl.dag_worker.checkpoint_manager import CheckpointManager
from siirl.dag_worker.core_algos import (
agg_loss,
apply_kl_penalty,
compute_advantage,
compute_response_mask
)
from siirl.dag_worker.dag_utils import (
log_ray_actor_info,
get_and_validate_rank,
get_taskgraph_for_rank,
log_role_worker_mapping,
should_create_worker,
generate_node_worker_key,
find_first_non_compute_ancestor,
setup_sharding_manager,
get_worker_classes,
get_parallelism_config,
prepare_generation_batch,
format_metrics_by_group,
log_metrics_to_console,
aggregate_and_write_performance_metrics,
log_core_performance_metrics,
timer,
reduce_and_broadcast_metrics,
whether_put_data
)
from siirl.utils.debug import DistProfiler
from siirl.utils.extras.device import get_device_name, get_nccl_backend
from siirl.execution.rollout_flow.multiturn.agent_loop import AgentLoopManager
device_name = get_device_name()
class DAGWorker(Worker):
"""
Orchestrates a Directed Acyclic Graph (DAG) of tasks for distributed training,
managing the setup, initialization, and workflow for a specific rank.
"""
def __init__(
self,
config: SiiRLArguments,
process_group_manager: ProcessGroupManager,
taskgraph_mapping: Dict[int, TaskGraph],
data_coordinator: "ray.actor.ActorHandle",
metric_worker: "ray.actor.ActorHandle",
device_name="cuda",
):
super().__init__()
self.config = config
self.process_group_manager = process_group_manager
self.taskgraph_mapping = taskgraph_mapping
self.data_coordinator = data_coordinator
self.device_name = device_name
self.enable_perf = os.environ.get("SIIRL_ENABLE_PERF", "0") == "1" or config.dag.enable_perf
# State attributes
self.timing_raw = {}
self.global_steps = 0
self.total_training_steps = 0
self.workers: Dict[str, Any] = {}
self.multi_agent_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)
self.agent_group_process_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)
self.process_groups: Dict[str, ProcessGroup] = {}
self.tokenizer_mapping: Dict[str, TokenizerModule] = {}
self.logger = None
self.progress_bar = None
self._rank: int = -1
self.taskgraph: Optional[TaskGraph] = None
self.internal_data_cache: Dict[str, Any] = {}
self.sample_ref_cache: list = []
self.agent_critic_worker: Any
# Finish flag
self.taskgraph_execute_finished = False
# async rollout
self.rollout_mode = "sync"
self._async_rollout_manager = None
self.zmq_address = None # used for async_vllmrollout
# Add a cache to hold data from an insufficient batch for the next training step.
# This is the core state-carrying mechanism for dynamic sampling.
self.sampling_leftover_cache: Optional[Any] = None
# multi agent
self._multi_agent = False
        # metric worker
self.metric_worker = MetricClient(metric_worker=metric_worker)
try:
self._initialize_worker()
except (ValueError, TypeError, KeyError, AttributeError, NotImplementedError) as e:
rank = os.environ.get("RANK", "UNKNOWN")
logger.error(f"Rank {rank}: Failed to create DAGWorker due to a critical setup error: {e}", exc_info=True)
raise DAGInitializationError(f"Initialization failed on Rank {rank}: {e}") from e
log_ray_actor_info(self._rank)
# ==========================================================================================
# Module 1: Execution and Training Loop
# ==========================================================================================
def execute_task_graph(self):
"""Main entry point to start the DAG execution pipeline."""
logger.info(f"Rank {self._rank}: Starting DAG execution pipeline...")
logger.success(f"Rank {self._rank}: All components initialized. Starting training loop from step {self.global_steps + 1}.")
if self.config.trainer.val_before_train:
self.validator.validate(global_step=self.global_steps)
self.metric_worker.wait_submit()
dist.barrier(self._gather_group)
if self._rank == 0 and self.logger:
val_metrics = self.metric_worker.wait_final_res()
logger.info(f"Initial validation metrics:\n{pformat(val_metrics)}")
self.logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.val_only:
logger.info("`val_only` is true. Halting after initial validation.")
return
self._run_training_loop()
if self.progress_bar:
self.progress_bar.close()
self.taskgraph_execute_finished = True
logger.success(f"Rank {self._rank}: DAG execution finished.")
def _run_training_loop(self):
"""
The main loop that iterates through training steps and epochs.
"""
self.total_training_steps = self.dataloader.total_training_steps
if self.dataloader.num_train_batches <= 0:
if self._rank == 0:
logger.warning(f"num_train_batches is {self.dataloader.num_train_batches}. The training loop will be skipped.")
return
if self._rank == 0:
self.progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
last_val_metrics = None
# Calculate starting epoch and batches to skip in that epoch for resumption.
start_epoch = 0
batches_to_skip = 0
if self.dataloader.num_train_batches > 0:
start_epoch = self.global_steps // self.dataloader.num_train_batches
batches_to_skip = self.global_steps % self.dataloader.num_train_batches
for epoch in range(start_epoch, self.config.trainer.total_epochs):
is_embodied = self.config.algorithm.workflow_type == WorkflowType.EMBODIED
if is_embodied:
self._cleanup_step_buffers(self.timing_raw)
for batch_idx in range(self.dataloader.num_train_batches):
if epoch == start_epoch and batch_idx < batches_to_skip:
continue
if self.global_steps >= self.total_training_steps:
logger.info(f"Rank {self._rank}: Reached total training steps. Exiting loop.")
if self._rank == 0 and last_val_metrics:
logger.info(f"Final validation metrics:\n{pformat(last_val_metrics)}")
return
if self.global_steps in self.config.profiler.profile_steps:
self._profiler.start(role="e2e", profile_step=self.global_steps)
ordered_metrics = self._run_training_step(epoch, batch_idx)
if self.global_steps in self.config.profiler.profile_steps:
self._profiler.stop()
if ordered_metrics is None:
if self.progress_bar:
self.progress_bar.update(1)
continue
self.global_steps += 1
is_last_step = self.global_steps >= self.total_training_steps
if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
self.checkpoint_manager.save_checkpoint(self.global_steps)
if self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
self.validator.validate(global_step=self.global_steps)
self.metric_worker.wait_submit()
dist.barrier(self._gather_group)
if self._rank == 0:
val_metric = self.metric_worker.wait_final_res()
ordered_metrics.update(val_metric)
if is_last_step:
last_val_metrics = val_metric
if self.enable_perf:
aggregate_and_write_performance_metrics(self._gather_group, self._rank, self.global_steps, self.config, ordered_metrics)
ordered_metric_dict = format_metrics_by_group(ordered_metrics, DAGConstants.METRIC_GROUP_ORDER)
log_core_performance_metrics(self._rank, self.enable_perf, ordered_metric_dict, self.global_steps)
if self._rank == 0:
if self.logger:
self.logger.log(data=ordered_metric_dict, step=self.global_steps)
else:
log_metrics_to_console(self._rank, ordered_metric_dict, self.global_steps)
if self.progress_bar and not (epoch == start_epoch and batch_idx < batches_to_skip):
self.progress_bar.update(1)
if self._rank == 0 and last_val_metrics:
logger.info(f"Final validation metrics:\n{pformat(last_val_metrics)}")
def _cleanup_step_buffers(self, timing_raw: dict) -> None:
"""
Encapsulates the logic for resetting and clearing all step-related buffers.
This includes the distributed Ray data buffers and the local internal cache.
This is called at the end of a step, whether it completed successfully or was aborted.
"""
# Reset the distributed (Ray) buffers for all keys that were used in this step.
with timer(self.enable_perf, "reset_data_buffer", timing_raw):
self.reset_data_buffer()
for ref in self.sample_ref_cache:
ray.internal.free(ref)
self.sample_ref_cache = []
# Clear the local, in-process cache for the next step.
with timer(self.enable_perf, "reset_intern_data_buffer", timing_raw):
self.internal_data_cache.clear()
    def _run_training_step(self, epoch: int, batch_idx: int) -> Optional[Dict[str, Any]]:
        """Executes a single training step by traversing the computational graph."""
        timing_raw = self.timing_raw
with timer(self.enable_perf, "step", timing_raw):
# --- 1. Data Loading ---
with timer(self.enable_perf, "get_data_from_dataloader", timing_raw):
repeat_n = self.config.actor_rollout_ref.rollout.n
batch = preprocess_dataloader(
self.dataloader.run(epoch=epoch, is_validation_step=False),
repeat_n
)
node_queue = self.taskgraph.get_entry_nodes()
if not node_queue:
logger.error("Taskgraph has no entry nodes. Cannot start execution.")
return None
entry_node_id = node_queue[0].node_id
# --- 2. Graph Traversal ---
visited_nodes = set()
with timer(self.enable_perf, "graph_execution", timing_raw):
while node_queue:
cur_node = node_queue.pop(0)
if cur_node.node_id in visited_nodes:
continue
visited_nodes.add(cur_node.node_id)
cur_dp_size, cur_dp_rank, cur_tp_rank, cur_tp_size, cur_pp_rank, cur_pp_size = self._get_node_dp_info(cur_node)
logger.debug(f"current node({cur_node.node_id}) dp_size: {cur_dp_size}, dp_rank: {cur_dp_rank}, tp_rank: {cur_tp_rank}, pp_rank: {cur_pp_rank}, pp_size: {cur_pp_size}")
# --- 3. Get Input Data ---
if cur_node.node_id != entry_node_id:
with timer(self.enable_perf, "get_data_from_buffer", timing_raw):
batch = self.get_data_from_buffers(key=cur_node.node_id, cur_dp_size=cur_dp_size, cur_dp_rank=cur_dp_rank, timing_raw=timing_raw)
if batch is None:
embodied_sampling = self.config.algorithm.embodied_sampling
allow_insufficient = (
self.config.algorithm.filter_groups.enable
or embodied_sampling.filter_accuracy
or embodied_sampling.filter_truncated
)
if allow_insufficient:
# Dynamic sampling scenario - waiting for data is expected behavior
if cur_node.node_role == NodeRole.ACTOR:
logger.debug(f"Rank {self._rank}: Waiting for sufficient data for node {cur_node.node_id}. Skipping this step.")
return None
else:
logger.error(f"Rank {self._rank}: Failed to get data for node {cur_node.node_id}. Skipping step.")
return None
else:
# batch = remove_prefix_from_dataproto(batch, cur_node)
logger.debug(f"current node({cur_node.node_id}) get data from databuffer batch size: {batch.size()}")
if self.enable_perf:
with timer(self.enable_perf, "get_data_from_buffer_barrier", timing_raw):
dist.barrier(self._gather_group)
# --- 4. Node Execution ---
node_name_timer = f"{cur_node.node_id}"
with timer(self.enable_perf, node_name_timer, timing_raw):
if cur_node.executable and batch is not None:
node_kwargs = {"_dag_worker_instance": self}
node_kwargs["process_group"] = self._get_node_process_group(cur_node) if cur_node.node_type != NodeType.COMPUTE else None
node_kwargs["agent_group"] = self.multi_agent_group[cur_node.agent_group]
node_kwargs["cur_tp_rank"] = cur_tp_rank
if cur_node.node_role == NodeRole.REWARD:
node_kwargs["tp_size"] = cur_tp_size
# Add parallelism info to batch for distributed reward computation
batch["dp_size"] = NonTensorData(cur_dp_size)
batch["dp_rank"] = NonTensorData(cur_dp_rank)
batch["tp_rank"] = NonTensorData(cur_tp_rank)
batch["tp_size"] = NonTensorData(cur_tp_size)
batch["pp_rank"] = NonTensorData(cur_pp_rank)
batch["pp_size"] = NonTensorData(cur_pp_size)
elif cur_node.node_role == NodeRole.ADVANTAGE:
node_kwargs["cur_node"] = cur_node
if cur_node.agent_options and cur_node.agent_options.train_cycle:
cycle_round = self.global_steps // cur_node.agent_options.train_cycle
agent_num = len(self.multi_agent_group)
if cycle_round % agent_num == cur_node.agent_group:
node_output = cur_node.run(batch=batch,
config=self.config,
**node_kwargs)
else:
node_output = NodeOutput(batch=batch)
else:
node_output = cur_node.run(batch=batch,
config=self.config,
**node_kwargs)
else:
logger.warning(f"Node {cur_node.node_id} has no executable. Passing data through.")
node_output = NodeOutput(batch=batch)
# Check if node returned empty batch (e.g., DAPO insufficient samples)
# This triggers re-rollout to collect more data
                    if node_output.batch is None or len(node_output.batch) == 0:
                        logger.warning(f"Rank {self._rank}: Node '{cur_node.node_id}' returned an empty batch.")
                        embodied_sampling = self.config.algorithm.embodied_sampling
                        allow_insufficient = (
                            self.config.algorithm.filter_groups.enable
                            or embodied_sampling.filter_accuracy
                            or embodied_sampling.filter_truncated
                        )
                        if not allow_insufficient:
                            logger.warning(f"Rank {self._rank}: Aborting current step to trigger re-rollout.")
                            return None
if self.enable_perf:
with timer(self.enable_perf, f"{node_name_timer}_barrier", timing_raw):
dist.barrier(self._gather_group)
if cur_node.node_role == NodeRole.ROLLOUT and self._multi_agent:
next_nodes = self.taskgraph.get_downstream_nodes(cur_node.node_id)
while next_nodes[0].node_role == NodeRole.ROLLOUT:
cur_node = next_nodes[0]
next_nodes = self.taskgraph.get_downstream_nodes(cur_node.node_id)
# --- 5. Process Output & Get next node ---
with timer(self.enable_perf, "graph_output_handling", timing_raw):
if node_output.metrics is not None and len(node_output.metrics) > 0 and cur_tp_rank == 0 and cur_pp_rank == 0:
self.metric_worker.submit_metric(node_output.metrics, cur_dp_size)
if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):
if node_output.batch is not None and len(node_output.batch) != 0:
# Currently supports single downstream node, can be extended to a loop.
next_node = next_nodes[0]
next_dp_size, _, _, _, _, _ = self._get_node_dp_info(next_node)
# node_output.batch = add_prefix_to_dataproto(node_output.batch, cur_node)
is_current_last_pp_tp_rank0 = (cur_pp_rank == cur_pp_size - 1 and cur_tp_rank == 0)
if whether_put_data(self._rank, is_current_last_pp_tp_rank0, next_dp_size, cur_dp_size, cur_node, next_node):
with timer(self.enable_perf, "put_data_to_buffer", timing_raw):
# Determine if we need to force data through DataCoordinator
# This is needed when filter causes data imbalance and requires rebalancing
embodied_sampling = self.config.algorithm.embodied_sampling
# Check if any filtering is enabled (causes data imbalance)
has_filtering = (
self.config.algorithm.filter_groups.enable
or embodied_sampling.filter_accuracy
or embodied_sampling.filter_truncated
)
# Check if current node is embodied filter node
is_embodied_filter_node = (cur_node.node_id == "embodied_sampling")
# Check if this is a COMPUTE -> consumer transition that needs rebalancing
is_compute_output = (cur_node.node_type == NodeType.COMPUTE)
needs_rebalance = (
next_node.node_type == NodeType.MODEL_TRAIN
or (is_embodied_filter_node and next_node.node_role == NodeRole.REWARD)
)
enforce_buffer = has_filtering and is_compute_output and needs_rebalance
self.put_data_to_buffers(key=next_node.node_id, data=node_output.batch, source_dp_size=cur_dp_size, dest_dp_size=next_dp_size, enforce_buffer=enforce_buffer, timing_raw=timing_raw)
# elif self._multi_agent:
# # last_node add prefix for metrics
# node_output.batch = add_prefix_to_dataproto(node_output.batch, cur_node)
if self.enable_perf:
with timer(self.enable_perf, "put_data_to_buffer_barrier", timing_raw):
dist.barrier(self._gather_group)
with timer(self.enable_perf, "get_next_node", timing_raw):
for n in next_nodes:
if n.node_id not in visited_nodes:
node_queue.append(n)
with timer(self.enable_perf, "step_barrier", timing_raw):
dist.barrier(self._gather_group)
# --- 6. Final Metrics Collection ---
self._cleanup_step_buffers(timing_raw)
ordered_metrics = {}
if cur_tp_rank == 0 and cur_pp_rank == 0:
self.metric_worker.compute_local_data_metric(batch, cur_dp_size)
            self.metric_worker.compute_local_throughout_metrics(batch, timing_raw, cur_pp_size * cur_tp_size, cur_dp_size)
if self._rank == 0:
# only use rank0 time metrics
self.metric_worker.compute_local_timing_metrics(batch, timing_raw, 1)
timing_raw.clear()
self.metric_worker.wait_submit()
dist.barrier(self._gather_group)
if self._rank == 0:
metrics = self.metric_worker.wait_final_res()
ordered_metrics = dict(sorted(metrics.items()))
ordered_metrics.update({"training/global_step": self.global_steps + 1, "training/epoch": epoch + 1})
return ordered_metrics
# ==========================================================================================
# Module 2: Graph Node Execution Handlers
# ==========================================================================================
@DistProfiler.annotate(role="generate")
def generate_sync_mode(self, agent_group, batch: TensorDict) -> NodeOutput:
"""Sync mode"""
gen_output = agent_group[NodeRole.ROLLOUT].generate_sequences(batch)
if "response_mask" not in batch:
gen_output["response_mask"] = compute_response_mask(gen_output)
batch = batch.update(gen_output)
return NodeOutput(batch=batch, metrics=gen_output["metrics"])
@DistProfiler.annotate(role="generate")
    def generate_async_mode(self, batch: TensorDict) -> NodeOutput:
        """Generates sequences asynchronously via the async rollout manager."""
        if self._async_rollout_manager is not None:
            loop = asyncio.get_event_loop()
            gen_output = loop.run_until_complete(self._async_rollout_manager.generate_sequences(batch))
            metrics = gen_output["metrics"]
            # Merge the generated sequences into the batch (mirrors the sync path).
            batch = batch.update(gen_output)
            if "response_mask" not in batch:
                batch["response_mask"] = compute_response_mask(batch)
            return NodeOutput(batch=batch, metrics=metrics)
        return NodeOutput(batch=batch, metrics={})
@DistProfiler.annotate(role="generate")
def generate_multi_agent_mode(self, config, batch: TensorDict) -> NodeOutput:
"""Generates sequences for a training batch using the multi-agent rollout model."""
gen_batch = prepare_generation_batch(batch)
if config.actor_rollout_ref.rollout.agent.rewards_with_env and "reward_model" in batch.non_tensor_batch:
gen_batch.non_tensor_batch["reward_model"] = batch.non_tensor_batch["reward_model"]
assert config.actor_rollout_ref.rollout.name == 'sglang'
gen_output = self.multi_agent_loop.generate_sequence(gen_batch)
if gen_output:
metrics = gen_output.meta_info.get("metrics", {})
# gen_output.meta_info = {}
# batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))])
# batch = batch.repeat(config.actor_rollout_ref.rollout.n, interleave=True).union(gen_output)
# if "response_mask" not in batch.batch:
# batch.batch["response_mask"] = compute_response_mask(batch)
return NodeOutput(batch=gen_output, metrics=metrics)
return NodeOutput(batch=batch, metrics={})
@DistProfiler.annotate(role="generate")
def generate_embodied_mode(self, agent_group, batch: TensorDict, **kwargs) -> NodeOutput:
"""
Generates embodied episodes for training.
This method follows the same pattern as _generate_for_embodied_validation in validation_mixin,
but configured for training mode (do_sample=True, validate=False).
For embodied tasks, the batch contains task metadata (task_id, trial_id, etc.) from the dataloader.
The rollout worker interacts with the environment and generates all required data
(input_ids, pixel_values, responses, etc.) during environment rollout.
        Unlike text generation, we do NOT call prepare_generation_batch because:
1. The input batch doesn't have text-generation keys (input_ids, attention_mask, etc.)
2. These keys will be generated by the embodied rollout worker during env interaction
"""
rollout_worker = agent_group[NodeRole.ROLLOUT]
rollout_n = self.config.actor_rollout_ref.rollout.n
# Set meta_info for embodied training
batch["eos_token_id"] = NonTensorData(self.validate_tokenizer.eos_token_id if self.validate_tokenizer else None)
batch["n_samples"] = NonTensorData(self.config.actor_rollout_ref.rollout.n)
batch["pad_token_id"] = NonTensorData(self.validate_tokenizer.pad_token_id if self.validate_tokenizer else None)
logger.info(
f"[Embodied Validation] Batch variables: "
f"{batch.batch_size[0]}, "
f"eos_token_id={batch['eos_token_id']}, "
f"pad_token_id={batch['pad_token_id']}, "
f"n_samples={batch['n_samples']} (dataloader already repeated {rollout_n}x), "
)
# Generate embodied episodes
gen_output = rollout_worker.generate_sequences(batch)
# Extract metrics (may be wrapped in NonTensorData)
raw_metrics = gen_output.get("metrics", {}) if hasattr(gen_output, "get") else {}
metrics = raw_metrics.data if hasattr(raw_metrics, 'data') else (raw_metrics if isinstance(raw_metrics, dict) else {})
# Merge generated data into batch
batch.update(gen_output)
# Compute response mask if not already present
if "response_mask" not in batch:
batch["response_mask"] = compute_response_mask(batch)
return NodeOutput(batch=batch, metrics=metrics)
@DistProfiler.annotate(role="generate")
def generate(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Generates sequences for a training batch using the rollout model."""
# Check if this is embodied mode
agent_group = kwargs.pop("agent_group")
is_embodied = self.config.actor_rollout_ref.model.model_type == "embodied"
if is_embodied:
# Use dedicated embodied generation path (mirrors validation logic)
return self.generate_embodied_mode(agent_group, batch, **kwargs)
if self._multi_agent is False:
if self.rollout_mode == 'sync':
return self.generate_sync_mode(agent_group, batch)
else:
return self.generate_async_mode(batch)
else:
return self.generate_multi_agent_mode(config, batch)
@DistProfiler.annotate(role="compute_reward")
def compute_reward(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Calculates rewards for a batch of generated sequences."""
if not self.check_mode() and kwargs["cur_tp_rank"] != 0:
return NodeOutput(batch=batch, metrics={})
tp_size = kwargs.pop("tp_size")
if "token_level_rewards" in batch and batch["token_level_rewards"].numel() > 0:
return NodeOutput(batch=batch, metrics={})
batch["global_token_num"] = NonTensorData((torch.sum(batch["attention_mask"], dim=-1) // tp_size).tolist())
reward_tensor, extra_infos = compute_reward(batch, self.reward_fn)
batch["token_level_scores"] = reward_tensor
if extra_infos:
batch.update({k: np.array(v) for k, v in extra_infos.items()}, inplace=True)
metrics = {}
if config.algorithm.use_kl_in_reward:
kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl_in_reward, config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch["token_level_rewards"] = batch["token_level_scores"]
return NodeOutput(batch=batch, metrics=metrics)
@DistProfiler.annotate(role="compute_old_log_prob")
def compute_old_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Computes log probabilities from the actor model before the policy update."""
process_group = kwargs.pop("process_group")
agent_group = kwargs.pop("agent_group")
if "global_token_num" not in batch:
            # In multi-agent training, agent A may not have a reward node,
            # so insert the info that downstream nodes need here.
batch["global_token_num"] = NonTensorData(torch.sum(batch["attention_mask"], dim=-1).tolist())
processed_data = agent_group[NodeRole.ACTOR].compute_log_prob(batch)
local_metrics = processed_data["metrics"] if "metrics" in processed_data else {}
if "entropys" in processed_data:
entropy = agg_loss(processed_data["entropys"], processed_data["response_mask"].to("cpu"), config.actor_rollout_ref.actor.loss_agg_mode)
local_metrics["actor/entropy_loss"] = entropy.item()
processed_data.pop("metrics", None)
processed_data.pop("entropys", None)
return NodeOutput(batch=processed_data, metrics=local_metrics)
@DistProfiler.annotate(role="compute_ref_log_prob")
def compute_ref_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Computes log probabilities from the frozen reference model."""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.REFERENCE].compute_ref_log_prob(batch)
metrics = processed_data["metrics"]
return NodeOutput(batch=processed_data, metrics=metrics)
@DistProfiler.annotate(role="compute_value")
def compute_value(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Computes value estimates from the critic model."""
agent_group = kwargs.pop("agent_group")
processed_data = agent_group[NodeRole.CRITIC].compute_values(batch)
return NodeOutput(batch=processed_data)
@DistProfiler.annotate(role="compute_advantage")
def compute_multi_agent_advantage(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
adv_config = config.algorithm
rollout_config = config.actor_rollout_ref.rollout
cur_node = kwargs["cur_node"]
if "token_level_rewards" not in batch.batch :
# make sure rewards of angentB has been compute
# GAE_MARFT adv need make sure only last agent has adv node
if depend_nodes := self.taskgraph.get_dependencies(cur_node.node_id):
depend_node = depend_nodes[0]
if adv_config.share_reward_in_agent:
batch.batch["token_level_rewards"] = batch.batch[f"agent_group_{depend_node.agent_group}_token_level_rewards"].clone()
else:
batch.batch["token_level_rewards"] = torch.zeros_like(batch.batch[f"agent_group_{depend_node.agent_group}_token_level_rewards"])
batch.batch["token_level_scores"] = batch.batch[f"agent_group_{depend_node.agent_group}_token_level_scores"].clone()
else:
raise RuntimeError(f"cur_node {cur_node.node_id} have no rewards with can't find it's dependencies reward")
if adv_config.adv_estimator == AdvantageEstimator.GAE_MARFT:
            # Make sure the advantage node is defined on the last agent.
cur_agent_id = len(self.multi_agent_group) - 1
agent_groups_ids = list(range(cur_agent_id))
kwargs["agent_group_ids"] = agent_groups_ids
            # Earlier agents may have no reward tokens of their own.
for agent_id in reversed(agent_groups_ids):
key_prefix = f"agent_group_{agent_id}_token_level_rewards"
if key_prefix not in batch.batch:
pre_key_prefix = f"agent_group_{agent_id + 1}_token_level_rewards" if agent_id != cur_agent_id -1 else "token_level_rewards"
if adv_config.share_reward_in_agent:
batch.batch[key_prefix] = batch.batch[pre_key_prefix].clone()
else:
batch.batch[key_prefix] = torch.zeros_like(batch.batch[pre_key_prefix])
batch.batch[f"agent_group_{agent_id}_token_level_scores"] = batch.batch[key_prefix].clone()
return NodeOutput(
batch=compute_advantage(
batch,
adv_estimator=adv_config.adv_estimator,
gamma=adv_config.gamma,
lam=adv_config.lam,
num_repeat=rollout_config.n,
norm_adv_by_std_in_grpo=adv_config.norm_adv_by_std_in_grpo,
weight_factor_in_cpgd=adv_config.weight_factor_in_cpgd,
multi_turn=rollout_config.multi_turn.enable,
**kwargs
)
)
@DistProfiler.annotate(role="compute_advantage")
def compute_advantage(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Computes advantages and returns for PPO using GAE."""
if not self.check_mode() and kwargs["cur_tp_rank"] != 0:
return NodeOutput(batch=batch, metrics={})
if self._multi_agent:
return self.compute_multi_agent_advantage(config, batch, **kwargs)
algo_config = config.algorithm
return NodeOutput(
batch=compute_advantage(
batch,
adv_estimator=algo_config.adv_estimator,
gamma=algo_config.gamma,
lam=algo_config.lam,
norm_adv_by_std_in_grpo=algo_config.norm_adv_by_std_in_grpo,
weight_factor_in_cpgd=algo_config.weight_factor_in_cpgd,
**kwargs
)
)
@DistProfiler.annotate(role="train_critic")
def train_critic(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Performs a single training step on the critic model."""
agent_group = kwargs.pop("agent_group")
process_group = kwargs.pop("process_group")
processed_data = agent_group[NodeRole.CRITIC].update_critic(batch)
return NodeOutput(batch=processed_data, metrics=processed_data["metrics"])
@DistProfiler.annotate(role="train_actor")
def train_actor(self, config, batch: TensorDict, **kwargs) -> NodeOutput:
"""Performs a single training step on the actor (policy) model."""
process_group = kwargs.pop("process_group")
agent_group = kwargs.pop("agent_group")
global_steps = batch["global_steps"] if "global_steps" in batch else 0
if config.trainer.critic_warmup > global_steps:
return NodeOutput(batch=batch) # Skip actor update during critic warmup
batch["multi_turn"] = NonTensorData(self.config.actor_rollout_ref.rollout.multi_turn.enable)
processed_data = agent_group[NodeRole.ACTOR].update_actor(batch)
return NodeOutput(batch=processed_data, metrics=processed_data["metrics"])
# ==========================================================================================
# Module 3: Worker and Environment Initialization
# ==========================================================================================
def _initialize_worker(self):
"""Orchestrates the ordered initialization of all worker components."""
self._rank = get_and_validate_rank()
self.taskgraph = get_taskgraph_for_rank(self._rank, self.taskgraph_mapping)
self._setup_distributed_environment()
self._setup_tokenizers()
self._setup_dataloader()
self._setup_reward_managers()
self._setup_role_worker_mapping()
self._initialize_node_workers()
self._profiler = DistProfiler(rank=self._rank, config=self.config.profiler)
# Initialize CheckpointManager - Note: will be fully initialized after workers are created
self.checkpoint_manager = None
# Initialize Validator - Note: will be initialized in init_graph() after all workers are ready
self.validator = None
# Initialize MetricsCollector - Note: will be initialized in init_graph() after all dependencies are ready
self.metrics_collector = None
if self._rank == 0:
logger.info("Rank 0: Initializing tracking logger...")
from siirl.utils.logger.tracking import Tracking
self.logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=self.config.to_dict(),
)
if self.enable_perf:
logger.warning("Performance tracking is enabled. This may impact training speed.")
def _setup_distributed_environment(self):
"""Initializes the default process group and all required subgroups."""
if not dist.is_initialized():
backend = (
f"{get_nccl_backend()}"
if self.world_size >= self.config.dag.backend_threshold
else f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}"
)
logger.info(
f"Rank {self._rank}: Initializing world size {self.world_size} default process group with '{backend}' "
f"backend."
)
dist.init_process_group(backend=backend)
if device_name == "npu":
# For NPU, metrics aggregation requires the hccl backend for device-to-device communication.
# This group is created regardless of world size for NPU environments.
gather_backend = get_nccl_backend()
self._gather_group = dist.new_group(backend=gather_backend)
else:
# For GPU, the original logic is preserved for backward compatibility.
# The gather group is only created if world_size < backend_threshold.
            self._gather_group = (
                dist.new_group(backend="gloo") if self.world_size < self.config.dag.backend_threshold else None
            )
group_specs = self.process_group_manager.get_all_specs()
if not group_specs:
logger.warning("No process group specifications found in ProcessGroupManager.")
return
        # Build all process groups defined in the ProcessGroupManager.
for name, spec in group_specs.items():
if not isinstance(spec, dict) or not (ranks := spec.get("ranks")):
logger.warning(f"Skipping group '{name}' due to invalid spec or missing 'ranks'.")
continue
self.process_groups[name] = dist.new_group(ranks=ranks)
logger.debug(f"Rank {self._rank}: Created {len(self.process_groups)} custom process groups.")
self.inference_group_name_set = self.process_group_manager.get_process_group_for_node_type_in_subgraph(
self.taskgraph.graph_id, NodeType.MODEL_INFERENCE.value
)
self.train_group_name_set = self.process_group_manager.get_process_group_for_node_type_in_subgraph(
self.taskgraph.graph_id, NodeType.MODEL_TRAIN.value
)
# Ensure all ranks have finished group creation before proceeding.
dist.barrier(self._gather_group)
logger.info(f"Rank {self._rank}: Distributed environment setup complete.")
def _setup_tokenizers(self):
"""Initializes and caches tokenizers for all models in the task graph."""
model_nodes = [
node
for node in self.taskgraph.nodes.values()
if node.node_type in [NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE]
]
if not model_nodes:
logger.warning("No model nodes found in the task graph. Tokenizer setup will be skipped.")
return
for node in model_nodes:
agent_key = f"group_key_{node.agent_group}"
if agent_key not in self.tokenizer_mapping:
# Add robust check for missing configuration.
intern_config = node.config.get(DAGConstants.INTERN_CONFIG)
if not intern_config or not (model_dict := getattr(intern_config, "model", None)):
logger.warning(f"Node {node.node_id} is missing model config. Skipping tokenizer setup for it.")
continue
tokenizer_module = load_tokenizer(model_args=model_dict)
if tokenizer := tokenizer_module.get("tokenizer"):
tokenizer.padding_side = "left" # Required for most causal LM generation
self.tokenizer_mapping[agent_key] = tokenizer_module
logger.info(f"Rank {self._rank}: Initialized {len(self.tokenizer_mapping)} tokenizer(s).")
def _setup_dataloader(self):
"""Initializes the data loader for training and validation."""
rollout_nodes = [n for n in self.taskgraph.nodes.values() if n.node_type == NodeType.MODEL_INFERENCE]
if not rollout_nodes:
raise ValueError("At least one MODEL_INFERENCE node is required for dataloader setup.")
self.first_rollout_node = rollout_nodes[0]
pg_assignment = self.process_group_manager.get_node_assignment(self.first_rollout_node.node_id)
if not (process_group_name := pg_assignment.get("process_group_name")):
raise ValueError(
f"Process group name not found for the first rollout node {self.first_rollout_node.node_id}."
)
self.dataloader_process_group = self.process_groups.get(process_group_name)
if self.dataloader_process_group is None:
raise ValueError(f"Could not find process group '{process_group_name}' in the created groups.")
self.dataloader_tensor_model_parallel_size = self.first_rollout_node.config[
DAGConstants.INTERN_CONFIG
].rollout.tensor_model_parallel_size
self.dataloader = DataLoaderNode(
node_id="dataloader",
global_config=self.config,
config={
"group_world_size": dist.get_world_size(self.dataloader_process_group),
"group_rank": dist.get_rank(self.dataloader_process_group),
"group_parallel_size": self.dataloader_tensor_model_parallel_size,
"num_loader_workers": self.config.data.num_loader_workers,
"auto_repeat": self.config.data.auto_repeat,
},
)
logger.info(f"Rank {self._rank}: DataLoader initialized with {self.dataloader.total_training_steps} total training steps.")
def _setup_reward_managers(self):
"""Initializes reward managers for training and validation."""
self.validate_tokenizer = next(iter(self.tokenizer_mapping.values()), {}).get("tokenizer")
if not self.validate_tokenizer:
logger.warning("No tokenizer loaded; reward functions might fail or use a default one.")
self.reward_fn = create_reward_manager(
self.config,
self.validate_tokenizer,
num_examine=0,
max_resp_len=self.config.data.max_response_length,
overlong_buffer_cfg=self.config.reward_model.overlong_buffer,
**self.config.reward_model.reward_kwargs,
)
logger.info(f"Rank {self._rank}: Reward managers initialized.")
def _setup_role_worker_mapping(self):
"""Creates a mapping from NodeRole to the corresponding Worker implementation class."""
self.role_worker_mapping: Dict[NodeRole, Type[Worker]] = {}
# Actor/Ref/Rollout/Critic workers
actor_strategy = self.config.actor_rollout_ref.actor.strategy
self.role_worker_mapping.update(get_worker_classes(self.config, actor_strategy))
# Reward model worker (if enabled)
if self.config.reward_model.enable:
reward_strategy = self.config.reward_model.strategy
reward_workers = get_worker_classes(self.config, reward_strategy)
if NodeRole.REWARD in reward_workers:
self.role_worker_mapping[NodeRole.REWARD] = reward_workers[NodeRole.REWARD]
else:
logger.warning(
f"Reward model is enabled, but no worker found for role REWARD with strategy {reward_strategy}."
)
log_role_worker_mapping(self.role_worker_mapping)
def _initialize_node_workers(self):
"""Instantiates worker objects for all nodes in the task graph."""
for node in self.taskgraph.nodes.values():
if not should_create_worker(self.role_worker_mapping, node):
continue
worker_cls = self.role_worker_mapping.get(node.node_role)
if not worker_cls:
logger.warning(f"No worker class found for role {node.node_role.name}. Skipping node {node.node_id}.")
continue
node_worker_key = generate_node_worker_key(node)
if node_worker_key in self.workers:
continue
try:
node_process_group = self._get_node_process_group(node)
config = node.config.get(DAGConstants.INTERN_CONFIG)
if hasattr(config, "actor") and hasattr(config.actor, "optim"):
config.actor.optim.total_training_steps = self.dataloader.total_training_steps
elif hasattr(config, "optim"):
config.optim.total_training_steps = self.dataloader.total_training_steps
worker_args = {"config": config, "process_group": node_process_group}
# For separated workers (Megatron backend), no role parameter is needed
# Only legacy ActorRolloutRefWorker needs the role parameter
if hasattr(worker_cls, '__name__') and 'ActorRolloutRefWorker' in worker_cls.__name__:
if node.node_role in DAGConstants.WORKER_ROLE_MAPPING:
worker_args["role"] = DAGConstants.WORKER_ROLE_MAPPING[node.node_role]
if node.agent_options and node.agent_options.share_instance:
                    # The current agent shares the same worker instance (e.g. critic) as the target agent.
self.multi_agent_group[node.agent_group][node.node_role] = self.multi_agent_group[node.agent_options.share_instance][node.node_role]
else:
worker_instance = worker_cls(**worker_args)
self.workers[node_worker_key] = worker_instance
self.multi_agent_group[node.agent_group][node.node_role] = worker_instance
self.agent_group_process_group[node.agent_group][node.node_role] = node_process_group
logger.success(
f"Rank {self._rank}: Successfully created worker '{worker_cls.__name__}' for node: {node.node_id}"
)
except Exception as e:
# Explicitly log the failing node and worker class, then re-raise
# the exception to prevent silent failures.
logger.error(
f"Failed to create worker for node {node.node_id} with class {worker_cls.__name__}.", exc_info=True
)
raise RuntimeError(f"Worker instantiation failed for node {node.node_id}") from e
if len(self.multi_agent_group) > 1:
self._multi_agent = True
def init_graph(self):
"""
Initializes the computation graph by loading models and restoring checkpoint state.
Executed after _initialize_worker() across all workers via Ray remote call.
        This method performs:
(1) model weight loading,
(2) weight sharding_manager setup,
(3) async/multi-agent init,
(4) validator init,
(5) metrics collector init,
(6) checkpoint restoration
"""
self._load_model_weights()
self._setup_sharding_manager()
self._setup_async_rollout()
self._setup_multi_agent_loop()
self._init_validator()
self._init_metrics_collector()
self._init_checkpoint_manager()
self.global_steps = self.checkpoint_manager.load_checkpoint()
dist.barrier(self._gather_group)
def _load_model_weights(self):
"""Loads model weights to GPU for all node workers."""
logger.info("Loading model weights for all worker nodes...")
initialized_workers = set()
for node in self.taskgraph.nodes.values():
if not should_create_worker(self.role_worker_mapping, node):
continue
worker_key = generate_node_worker_key(node)
if worker_key in initialized_workers:
continue
node_worker = self.workers[worker_key]
if not isinstance(node_worker, Worker):
raise TypeError(f"Invalid worker type for node {node.node_id}: {type(node_worker).__name__}")
node_worker.init_model()
initialized_workers.add(worker_key)
logger.success("All model weights loaded successfully.")
def _setup_sharding_manager(self):
"""Sets up sharding managers for actor-rollout weight synchronization."""
logger.info(f"Setting up weight sharing infrastructure ({self.config.actor_rollout_ref.rollout.name})...")
for agent_group, worker_dict in self.multi_agent_group.items():
if NodeRole.ACTOR in worker_dict and NodeRole.ROLLOUT in worker_dict:
try:
setup_sharding_manager(
self.config,
self.agent_group_process_group,
agent_group,
worker_dict
)
except Exception as e:
logger.error(f"Failed to set up sharding manager for agent group {agent_group}: {e}", exc_info=True)
raise
logger.success("Weight sharing infrastructure initialized.")
def _setup_async_rollout(self):
"""Initializes async rollout server if configured."""
if self.config.actor_rollout_ref.rollout.mode != "async":
return
logger.info("Initializing async rollout server...")
for node in self.taskgraph.nodes.values():
if node.node_role == NodeRole.ROLLOUT:
self.rollout_mode = "async"
node_worker = self.workers[generate_node_worker_key(node)]
self.zmq_address = node_worker.get_zeromq_address()
self.init_async_server(node=node, node_worker=node_worker)
logger.success("Async rollout server initialized.")
def _setup_multi_agent_loop(self):
"""Initializes multi-agent loop if in multi-agent mode."""
if not self._multi_agent:
return
logger.info("Initializing multi-agent loop...")
from siirl.execution.rollout_flow.multi_agent.multiagent_generate import MultiAgentLoop
self.multi_agent_loop = MultiAgentLoop(
self,
config=self.config.actor_rollout_ref,
node_workers=self.workers,
local_dag=self.taskgraph,
databuffer=self.data_buffers,
placement_mode='colocate'
)
logger.success("Multi-agent loop initialized.")
def _init_validator(self):
"""Initializes validator for validation workflow."""
logger.info("Initializing validator...")
from siirl.dag_worker.validator import Validator
self.validator = Validator(
config=self.config,
dataloader=self.dataloader,
validate_tokenizer=self.validate_tokenizer,
multi_agent_group=self.multi_agent_group,
rollout_mode=self.rollout_mode,
async_rollout_manager=self._async_rollout_manager,
multi_agent_loop=getattr(self, 'multi_agent_loop', None),
multi_agent=self._multi_agent,
rank=self._rank,
world_size=self.world_size,
gather_group=self._gather_group,
first_rollout_node=self.first_rollout_node,
get_node_dp_info_fn=self._get_node_dp_info,
enable_perf=self.enable_perf,
metric_worker=self.metric_worker
)
logger.success("Validator initialized.")
def _init_metrics_collector(self):
"""Initializes metrics collector for training metrics aggregation."""
logger.info("Initializing metrics collector...")
# from siirl.dag_worker.metrics_collector import MetricsCollector
# self.metric_worker.init()
# self.metrics_collector = MetricsCollector(
# rank=self._rank,
# world_size=self.world_size,
# gather_group=self._gather_group,
# taskgraph=self.taskgraph,
# first_rollout_node=self.first_rollout_node,
# get_node_dp_info_fn=self._get_node_dp_info,
# multi_agent=self._multi_agent,
# enable_perf=self.enable_perf,
# )
# logger.success("Metrics collector initialized.")
def _init_checkpoint_manager(self):
"""Initializes checkpoint manager for saving/loading training state."""
logger.info("Initializing checkpoint manager...")
self.checkpoint_manager = CheckpointManager(
config=self.config,
rank=self._rank,
gather_group=self._gather_group,
workers=self.workers,
taskgraph=self.taskgraph,
dataloader=self.dataloader,
first_rollout_node=self.first_rollout_node,
get_node_dp_info_fn=self._get_node_dp_info
)
    def init_async_server(self, node: Node, node_worker):
        # Gather the ZMQ addresses of all TP ranks to each TP group's master (tp_rank 0).
_, dp_rank, tp_rank, tp_size, *_ = self._get_node_dp_info(node)
addr_len = len(self.zmq_address)
encoded_addr = torch.tensor([ord(c) for c in self.zmq_address], dtype=torch.uint8,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
zmq_addresses = []
if tp_rank == 0:
group_addrs = torch.zeros((tp_size, addr_len), dtype=torch.uint8, device=encoded_addr.device)
group_addrs[0] = encoded_addr
for i in range(1, tp_size):
src_rank = dp_rank * tp_size + i
dist.recv(group_addrs[i], src=src_rank)
for i in range(tp_size):
addr_str = ''.join([chr(c.item()) for c in group_addrs[i]])
zmq_addresses.append(addr_str)
else:
dist.send(encoded_addr, dst=dp_rank * tp_size)
if tp_rank == 0:
self._async_rollout_manager = AgentLoopManager(node.config["intern_config"], dp_rank, os.environ['WG_PREFIX'], node_worker.rollout, zmq_addresses)
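    # The address exchange above serializes each ZMQ endpoint into a uint8 tensor so
    # it can travel over the existing process group. A minimal round-trip sketch
    # (note the fixed-width recv buffer implicitly assumes all TP ranks produce
    # addresses of the same length):
    #
    #   addr = "tcp://10.0.0.1:5555"                                 # hypothetical
    #   encoded = torch.tensor([ord(c) for c in addr], dtype=torch.uint8)
    #   decoded = ''.join(chr(c) for c in encoded.tolist())
    #   assert decoded == addr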
# ==========================================================================================
# Module 4: Utilities
# ==========================================================================================
def put_data_to_buffers(
self, key: str,
data: TensorDict,
        source_dp_size: int,
dest_dp_size: int,
enforce_buffer: bool,
timing_raw: Dict[str, float]
):
"""
Puts data into the DataCoordinator by converting it into individual Samples.
The data is tagged with a 'key' to be retrieved by the correct downstream node.
"""
try:
batch_size = len(data) if data is not None else 0
if source_dp_size == dest_dp_size and not enforce_buffer:
with timer(self.enable_perf, f"put_intern_data_{key}", timing_raw):
self.internal_data_cache[key] = data
else:
samples = Dict2Samples(data)
if not samples:
logger.warning(f"Rank {self._rank}: TensorDict for key '{key}' converted to 0 samples. Nothing to put.")
return
with timer(self.enable_perf, f"put_samples_to_coordinator_{key}", timing_raw):
sample_infos = []
for sample in samples:
# Convert uid to string (handle tensor uid from postprocess_sampling)
uid_val = getattr(sample, 'uid', uuid.uuid4().int)
if isinstance(uid_val, torch.Tensor):
uid_str = str(uid_val.item()) # Works for both int and string tensors
elif hasattr(uid_val, 'tolist'):
uid_str = str(uid_val.tolist()) # Handle numpy types
else:
uid_str = str(uid_val)
sample_infos.append(SampleInfo(
sum_tokens=getattr(sample, 'sum_tokens', int(sample.attention_mask.sum())),
prompt_length=getattr(sample, 'prompt_length', 0),
response_length=getattr(sample, 'response_length', 0),
uid=uid_str,
dict_info={
'key': key,
'source_dp_size': source_dp_size # Store source DP size
}
))
# Although ray.put is called multiple times, it is more efficient than remote actor calls.
# This is the main source of the remaining overhead, but it is necessary
# to maintain sample-level traceability in the DataCoordinator.
with timer(self.enable_perf, f"ray_put_samples_{key}", timing_raw):
sample_refs = [ray.put(sample) for sample in samples]
self.sample_ref_cache.extend(sample_refs)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
caller_node_id = ray.get_runtime_context().get_node_id()
put_future = self.data_coordinator.put_batch.remote(sample_infos, sample_refs, caller_node_id)
loop.run_until_complete(put_future)
if self._rank == 0:
logger.info(f"Rank 0: PUT {len(samples)} samples to DataCoordinator for '{key}'")
except Exception as e:
logger.error(f"Rank {self._rank}: Unexpected error in put_data_to_buffers for key '{key}': {e}", exc_info=True)
raise
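    # Sketch of the uid normalization performed above, with hypothetical inputs
    # (any of these shapes may arrive from upstream nodes):
    #
    #   torch.tensor(42)   -> "42"       (via .item())
    #   np.int64(42)       -> "42"       (via .tolist())
    #   uuid.uuid4().int   -> "3148..."  (plain str())
    #
    # Normalizing to str keeps SampleInfo.uid a plain, hashable string regardless
    # of which node produced the sample.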
def get_data_from_buffers(
self,
key: str,
cur_dp_size: int,
cur_dp_rank: int,
timing_raw: Dict[str, float]
) -> Optional[TensorDict]:
"""
Gets data from the DataCoordinator by filtering for a specific key,
then collates the resulting Samples back into a single TensorDict.
Args:
key: The key to filter samples
cur_dp_size: Current node's DP size
cur_dp_rank: Current worker's DP rank
timing_raw: Timing dict for performance tracking
"""
with timer(self.enable_perf, f"get_intern_data_{key}", timing_raw):
if key in self.internal_data_cache:
cached_data = self.internal_data_cache.pop(key)
return cached_data
def key_filter(sample_info: SampleInfo) -> bool:
return sample_info.dict_info.get('key') == key
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
with timer(self.enable_perf, f"get_samples_from_coordinator_{key}", timing_raw):
try:
rollout_n = self.config.actor_rollout_ref.rollout.n if hasattr(self.config, 'actor_rollout_ref') else 1
except (AttributeError, KeyError):
rollout_n = 1
if rollout_n is None or rollout_n < 1:
rollout_n = 1
adjusted_batch_size = int(self.config.data.train_batch_size * rollout_n / cur_dp_size)
logger.debug(
f"Rank {self._rank}: Requesting from DataCoordinator: "
f"key='{key}', cur_dp={cur_dp_size}, "
f"adjusted_batch_size={adjusted_batch_size} (train_bs={self.config.data.train_batch_size} * rollout_n={rollout_n} / cur_dp={cur_dp_size})"
)
# Use filter_plugin to get only samples with matching key
# Use balance_partitions to optimize sample distribution by length
# Use cache_key to enable multi-rank caching within the same node
sample_refs = loop.run_until_complete(
self.data_coordinator.get_batch.remote(
adjusted_batch_size,
cur_dp_rank,
filter_plugin=key_filter,
balance_partitions=cur_dp_size,
cache_key=key
)
)
# Check if dynamic sampling is enabled (DAPO/embodied)
embodied_sampling = self.config.algorithm.embodied_sampling
is_dynamic_sampling = (
self.config.algorithm.filter_groups.enable
or embodied_sampling.filter_accuracy
or embodied_sampling.filter_truncated
)
if not sample_refs:
if is_dynamic_sampling:
logger.debug(f"Rank {self._rank}: Waiting for data accumulation for key '{key}' (need {adjusted_batch_size} samples)")
else:
logger.warning(f"Rank {self._rank}: DataCoordinator returned empty list for key '{key}' (adjusted_batch_size={adjusted_batch_size})")
return None
if self._rank == 0:
logger.info(f"Rank 0: GET {len(sample_refs)} samples from DataCoordinator for '{key}'")
with timer(self.enable_perf, f"ray_get_samples_{key}", timing_raw):
samples = ray.get(sample_refs)
with timer(self.enable_perf, f"collate_samples_{key}", timing_raw):
tensordict = Samples2Dict(samples)
return tensordict
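    # Worked example for `adjusted_batch_size` above (hypothetical numbers): with
    # train_batch_size=32, rollout.n=4 and cur_dp_size=8, each DP rank requests
    # int(32 * 4 / 8) = 16 samples per step, so the full rollout of 128 samples is
    # split evenly across the 8 data-parallel ranks.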
def reset_data_buffer(self):
"""
        Largely deprecated with the DataCoordinator, since get calls are now consuming.
        On rank 0 it still clears the coordinator's cache for safety; on all other
        ranks it is a no-op.
        """
        logger.debug("`reset_data_buffer`: gets are consuming under the DataCoordinator; only the rank-0 cache is cleared.")
if self._rank == 0:
self.data_coordinator.reset_cache.remote()
def _get_node_process_group(self, node: Node) -> ProcessGroup:
"""Retrieves the PyTorch ProcessGroup assigned to a specific graph node."""
assignment = self.process_group_manager.get_node_assignment(node.node_id)
if not (assignment and (name := assignment.get("process_group_name"))):
raise ValueError(f"Process group assignment or name not found for node {node.node_id}.")
pg = self.process_groups.get(name)
if pg is None:
raise ValueError(f"Process group '{name}' for node {node.node_id} was not created or found.")
return pg
def _get_node_dp_info(self, node: Node) -> tuple[int, int, int, int, int, int]:
"""
Calculates Data Parallel (DP), Tensor Parallel (TP), and Pipeline Parallel (PP) info for a node.
Returns:
tuple: (dp_size, dp_rank, tp_rank, tp_size, pp_rank, pp_size)
"""
reference_node = node
if node.node_type == NodeType.COMPUTE:
# If the node is a COMPUTE type, find its true data source ancestor.
ancestor = find_first_non_compute_ancestor(self.taskgraph, node.node_id)
if ancestor:
reference_node = ancestor
else:
# If no non-COMPUTE ancestor is found, it's a critical error.
raise RuntimeError(f"Could not find any non-COMPUTE ancestor for COMPUTE node '{node.node_id}'. Please check your DAG graph configuration.")
if reference_node.node_type == NodeType.COMPUTE:
group_world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
group_rank = dist.get_rank()
else:
process_group = self._get_node_process_group(reference_node)
group_world_size = dist.get_world_size(process_group)
group_rank = dist.get_rank(process_group)
# Get parallelism configuration based on backend strategy
tp_size, pp_size = get_parallelism_config(reference_node)
# Calculate total parallel size (TP * PP)
total_parallel_size = tp_size * pp_size
if group_world_size % total_parallel_size != 0:
raise ValueError(f"Configuration error for node {node.node_id}: Group world size ({group_world_size}) is not divisible by total parallel size (TP={tp_size} * PP={pp_size} = {total_parallel_size}). Check your parallel configuration.")
dp_size = group_world_size // total_parallel_size
# Calculate ranks within the data parallel group
dp_rank = group_rank // total_parallel_size
# Calculate position within the TP-PP grid
local_rank_in_tp_pp_group = group_rank % total_parallel_size
# For 2D parallelism: ranks are arranged as [PP0_TP0, PP0_TP1, ..., PP0_TP(tp_size-1), PP1_TP0, ...]
pp_rank = local_rank_in_tp_pp_group // tp_size
tp_rank = local_rank_in_tp_pp_group % tp_size
return dp_size, dp_rank, tp_rank, tp_size, pp_rank, pp_size
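    # Worked example of the rank arithmetic above (hypothetical sizes): with
    # group_world_size=16, tp_size=2, pp_size=2 -> total_parallel_size=4 and
    # dp_size=4. For group_rank=5:
    #   dp_rank = 5 // 4 = 1
    #   local   = 5 %  4 = 1  -> pp_rank = 1 // 2 = 0, tp_rank = 1 % 2 = 1
    # i.e. rank 5 is the second TP shard of the first PP stage in DP group 1.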
def get_zeromq_address(self):
return self.zmq_address
def multi_agent_put_log(self, key: str, data: TensorDict, agent_group: int, next_dp_size: int, timing_raw):
        # This logic has not yet been adapted to the DataCoordinator model; warn and no-op.
        logger.warning("`multi_agent_put_log` is not yet refactored for DataCoordinator and is a no-op.")
def check_mode(self):
        return self.rollout_mode == 'sync' and not self._multi_agent
================================================
FILE: siirl/dag_worker/data_structures.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Dict
import torch
from tensordict import TensorDict
@dataclass
class ValidationResult:
"""A structured container for a single validation sample's results."""
input_text: str
output_text: str
score: float
data_source: str
reward_tensor: torch.Tensor
extra_rewards: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ValidationPayload:
"""A lightweight, serializable container for validation metrics for efficient gathering."""
input_text: str
score: float
data_source: str
extra_rewards: Dict[str, Any] = field(default_factory=dict)
@dataclass
class NodeOutput:
"""A standardized return object for all node execution functions."""
batch: TensorDict
metrics: Dict[str, Any] = field(default_factory=dict)
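# A minimal usage sketch of these containers (hypothetical values; TensorDict
# requires an explicit batch_size):
#
#   result = ValidationResult(
#       input_text="1+1=", output_text="2", score=1.0,
#       data_source="gsm8k", reward_tensor=torch.tensor([0.0, 1.0]),
#   )
#   payload = ValidationPayload(result.input_text, result.score, result.data_source)
#   out = NodeOutput(batch=TensorDict({}, batch_size=[]), metrics={"score": result.score})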
================================================
FILE: siirl/dag_worker/metric_aggregator.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
# Copyright 2025, Infrawaves. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from siirl.utils.extras.device import get_device_id, get_device_name
class _ReduceOp(Enum):
"""Enumeration for supported reduction operations."""
SUM = dist.ReduceOp.SUM
MAX = dist.ReduceOp.MAX
MIN = dist.ReduceOp.MIN
# Configuration for metrics that require mean, max, and min aggregation.
# Format: { "key_in_local_data": "final_metric_prefix" }
METRIC_CONFIG_FULL = {
"score": "critic/score",
"rewards": "critic/rewards",
"advantages": "critic/advantages",
"returns": "critic/returns",
"values": "critic/values",
"response_length": "response/length",
"prompt_length": "prompt/length",
"correct_response_length": "response/correct_length",
"wrong_response_length": "response/wrong_length",
}
# Configuration for metrics that only require mean aggregation.
# Format: { "key_in_local_data": "final_metric_prefix" }
METRIC_CONFIG_MEAN_ONLY = {
"response_clip_ratio": "response/clip_ratio",
"prompt_clip_ratio": "prompt/clip_ratio",
}
class DistributedMetricAggregator:
"""
A helper class to encapsulate the logic for aggregating metrics
in a distributed environment.
"""
def __init__(
self, local_metrics: Dict[str, Union[float, List[float], torch.Tensor]], group: Optional[dist.ProcessGroup]
):
"""
Initializes the aggregator and prepares metrics for reduction.
Args:
local_metrics: The dictionary of metrics on the local rank.
group: The process group for distributed communication.
"""
self.group = group
device_name = get_device_name()
if device_name in ["cuda", "npu"]:
self.device = f"{device_name}:{get_device_id()}"
else:
self.device = "cpu"
self.op_buckets = self._bucket_local_metrics(local_metrics)
def _bucket_local_metrics(self, metrics: Dict, expected_keys: set = None) -> defaultdict:
"""
Parses local metrics and groups them by the required reduction operation.
This step also performs local pre-aggregation on lists and tensors.
This version correctly handles multi-element tensors as input.
For Pipeline Parallel (PP), different stages may have different metrics.
This method ensures all ranks have the same set of keys by adding missing
metrics with default values (0.0) to avoid tensor shape mismatch in all_reduce.
Args:
metrics: Local metrics dictionary
expected_keys: Optional set of all expected metric keys across all ranks
Returns:
A defaultdict containing keys and pre-aggregated values,
grouped by reduction operation type (_ReduceOp).
"""
buckets = defaultdict(list)
# If expected_keys is provided, ensure all ranks have the same metrics
if expected_keys:
# define metrics that should be excluded from non-computing ranks
# these are training-specific metrics that only the last PP stage should contribute
training_metrics = {
'actor/pg_loss', 'actor/kl_loss', 'actor/entropy_loss', 'actor/ppo_kl',
'actor/pg_clipfrac', 'actor/pg_clipfrac_lower', 'actor/kl_coef',
'critic/vf_loss', 'critic/clipfrac'
}
# Token counting metrics should only be contributed by PP rank 0 to avoid double counting
token_counting_metrics = {
'perf/total_num_tokens/mean'
}
for key in expected_keys:
if key not in metrics:
# for training metrics: use None to indicate this rank shouldn't contribute
# for other metrics: use 0.0 as default
if any(key.startswith(prefix) for prefix in training_metrics) or key in token_counting_metrics:
# mark as None - will be handled specially in aggregation
metrics[key] = None
else:
# performance metrics get default value 0.0
metrics[key] = 0.0
for key in sorted(metrics.keys()):
value = metrics[key]
# Skip None values (training metrics from non-contributing ranks)
if value is None:
# for training metrics that this rank (those ranks that are not the last PP stage) shouldn't contribute to,
# add with count=0 so it doesn't affect the average
buckets[_ReduceOp.SUM].append((key, (0.0, 0)))
continue
# Determine if the value is a list or a tensor that needs aggregation
is_list = isinstance(value, list)
is_tensor = isinstance(value, torch.Tensor)
if "_max" in key:
op_type = _ReduceOp.MAX
if is_tensor:
# Use torch.max for tensors, get the scalar value
local_val = torch.max(value).item() if value.numel() > 0 else 0.0
elif is_list:
local_val = max(value) if value else 0.0
else: # Is a scalar float
local_val = value
buckets[op_type].append((key, local_val))
elif "_min" in key:
op_type = _ReduceOp.MIN
if is_tensor:
local_val = torch.min(value).item() if value.numel() > 0 else 0.0
elif is_list:
local_val = min(value) if value else 0.0
else:
local_val = value
buckets[op_type].append((key, local_val))
else: # Default to mean calculation (SUM operation).
op_type = _ReduceOp.SUM
if is_tensor:
local_sum = torch.sum(value).item()
local_count = value.numel()
elif is_list:
local_sum = sum(value) if value else 0.0
local_count = len(value)
else: # Is a scalar float
local_sum = value
local_count = 1
buckets[op_type].append((key, (local_sum, local_count)))
return buckets
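    # Sketch of the key-suffix convention used above (hypothetical local metrics):
    #   {"loss": [1.0, 3.0], "loss_max": 3.0, "loss_min": 1.0}
    # buckets to:
    #   SUM -> [("loss", (4.0, 2))]   # (local_sum, local_count) for a global mean
    #   MAX -> [("loss_max", 3.0)]
    #   MIN -> [("loss_min", 1.0)]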
def aggregate_and_get_results(self) -> Dict[str, float]:
"""
Performs the distributed all_reduce operations and composes the final
metrics dictionary.
Returns:
A dictionary with the globally aggregated metrics.
"""
final_metrics = {}
for op_type, data in self.op_buckets.items():
if not data:
continue
keys, values = zip(*data)
if op_type == _ReduceOp.SUM:
sums, counts = zip(*values)
sum_tensor = torch.tensor(sums, dtype=torch.float32, device=self.device)
count_tensor = torch.tensor(counts, dtype=torch.float32, device=self.device)
if self.group is not None:
dist.all_reduce(sum_tensor, op=op_type.value, group=self.group)
dist.all_reduce(count_tensor, op=op_type.value, group=self.group)
global_sums = sum_tensor.cpu().numpy()
global_counts = count_tensor.cpu().numpy()
for i, key in enumerate(keys):
final_metrics[key] = global_sums[i] / global_counts[i] if global_counts[i] > 0 else 0.0
else: # MAX or MIN operations
value_tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
if self.group is not None:
dist.all_reduce(value_tensor, op=op_type.value, group=self.group)
global_values = value_tensor.cpu().numpy()
for i, key in enumerate(keys):
final_metrics[key] = global_values[i]
return final_metrics
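    # End-to-end sketch across two ranks (hypothetical values): if rank 0 holds
    # ("loss", (4.0, 2)) and rank 1 holds ("loss", (6.0, 3)), the two all_reduce
    # calls above yield sum=10.0 and count=5, so the reported global value is
    # 10.0 / 5 = 2.0 -- a true sample mean rather than a mean of per-rank means.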
================================================
FILE: siirl/dag_worker/metrics_collector.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MetricsCollector: Orchestrates metrics aggregation for DAG worker training steps.
This module provides the MetricsCollector class which handles the collection,
aggregation, and formatting of training metrics across distributed workers.
"""
import os
import psutil
import torch
import torch.distributed as dist
from typing import Dict, Callable, Any
from torch.distributed import ProcessGroup
from tensordict import TensorDict
# Handle different tensordict versions - NonTensorData location varies
try:
from tensordict import NonTensorData
except ImportError:
from tensordict.tensorclass import NonTensorData
from siirl.execution.dag import TaskGraph
from siirl.execution.dag.node import Node, NodeRole
from siirl.dag_worker.metric_aggregator import METRIC_CONFIG_FULL, METRIC_CONFIG_MEAN_ONLY
from siirl.dag_worker.dag_utils import (
timer,
prepare_local_batch_metrics,
reduce_and_broadcast_metrics,
remove_prefix_from_dataproto,
add_prefix_to_dataproto,
add_prefix_to_metrics,
)
from siirl.utils.extras.device import get_device_name, get_device_id
from siirl.utils.metrics.metric_utils import compute_throughout_metrics, compute_timing_metrics
from loguru import logger
class MetricsCollector:
"""
Collects and aggregates training metrics across distributed workers.
This class orchestrates the final metrics collection for each training step,
including batch metrics, performance metrics, and timing information. It uses
the existing four-layer metrics architecture:
- Layer 1: torch.distributed (communication)
- Layer 2: DistributedMetricAggregator (aggregation engine)
- Layer 3: dag_utils functions (utility layer)
- Layer 4: MetricsCollector (orchestration layer)
"""
def __init__(
self,
rank: int,
world_size: int,
gather_group: ProcessGroup,
taskgraph: TaskGraph,
first_rollout_node: Node,
get_node_dp_info_fn: Callable,
multi_agent: bool = False,
enable_perf: bool = False,
):
"""
Initialize MetricsCollector with explicit parameters.
Args:
rank: Current process rank
world_size: Total number of processes
gather_group: Process group for metric aggregation
taskgraph: DAG task graph for node iteration
first_rollout_node: Reference rollout node for configuration
get_node_dp_info_fn: Function to get node DP/TP/PP info
multi_agent: Whether in multi-agent mode
enable_perf: Whether to enable performance timing
"""
self.rank = rank
self.world_size = world_size
self.gather_group = gather_group
self.taskgraph = taskgraph
self.first_rollout_node = first_rollout_node
self.get_node_dp_info_fn = get_node_dp_info_fn
self.multi_agent = multi_agent
self.enable_perf = enable_perf
def collect_final_metrics(self, batch: TensorDict, timing_raw: dict) -> Dict[str, float]:
"""
Orchestrates the collection and computation of all metrics for a training step
using a highly efficient, all_reduce-based aggregation strategy.
This function replaces the old `compute -> reduce -> finalize` pipeline.
Args:
batch: Final batch data (TensorDict) containing all computed values
timing_raw: Dictionary of raw timing measurements
Returns:
Dictionary of aggregated metrics
"""
device_name = get_device_name()
if device_name == "cuda":
torch.cuda.reset_peak_memory_stats()
elif device_name == "npu":
torch.npu.reset_peak_memory_stats()
final_metrics = {}
# --- 1. Prepare all local metric data ---
use_critic = any(node.node_role == NodeRole.CRITIC for node in self.taskgraph.nodes.values())
local_data = prepare_local_batch_metrics(batch, use_critic=use_critic)
# --- 2. Build the dictionary for our generic, high-performance aggregator ---
# We want mean, max, and min for most standard metrics.
metrics_to_aggregate = {}
# Process metrics requiring mean, max, and min
for key, prefix in METRIC_CONFIG_FULL.items():
if key in local_data:
# The aggregator determines the operation from the key.
# We provide the same raw tensor for mean, max, and min calculations.
metrics_to_aggregate[f"{prefix}/mean"] = local_data[key]
metrics_to_aggregate[f"{prefix}_max"] = local_data[key]
metrics_to_aggregate[f"{prefix}_min"] = local_data[key]
# Process metrics requiring only mean
for key, prefix in METRIC_CONFIG_MEAN_ONLY.items():
if key in local_data:
metrics_to_aggregate[f"{prefix}/mean"] = local_data[key]
representative_actor_node = next(
(n for n in self.taskgraph.nodes.values() if n.node_role == NodeRole.ACTOR), self.first_rollout_node
)
_, _, _, _, pp_rank_in_group, _ = self.get_node_dp_info_fn(representative_actor_node)
# (1) For TP: we have already taken TP into account when we set global_token_num in compute_reward.
# see: siirl/workers/dag_worker/mixins/node_executors_mixin.py:compute_reward
# (2) For PP: only PP rank 0 contributes to avoid double counting within PP groups
# The aggregation will average across DP groups and multiply by world size to get global estimate
if pp_rank_in_group == 0:
local_token_sum = sum(batch["global_token_num"])
metrics_to_aggregate["perf/total_num_tokens/mean"] = float(local_token_sum)
# --- 3. Perform the aggregated, distributed reduction ---
with timer(self.enable_perf, "metrics_aggregation", timing_raw):
aggregated_metrics = reduce_and_broadcast_metrics(metrics_to_aggregate, self.gather_group)
# Post-process keys and values for the final output
for key, value in aggregated_metrics.items():
if "_max" in key and "mem" not in key:
final_metrics[key.replace("_max", "/max")] = value
elif "_min" in key:
final_metrics[key.replace("_min", "/min")] = value
else:
final_metrics[key] = value
# Special handling for total_num_tokens to convert mean back to sum
if "perf/total_num_tokens/mean" in final_metrics:
final_metrics["perf/total_num_tokens"] = final_metrics.pop(
"perf/total_num_tokens/mean"
) * dist.get_world_size(self.gather_group)
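        # For example, "critic/score_max" from the aggregator becomes "critic/score/max"
        # above, and the token mean is scaled by the gather-group size to recover an
        # approximate global token sum.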
# --- 4. Handle special cases like Explained Variance ---
if use_critic:
# Determine the correct device for distributed operations
device_name = get_device_name()
if device_name in ["cuda", "npu"]:
device = f"{device_name}:{get_device_id()}"
else:
                # Fall back to the device of an existing tensor. If that device is CPU,
                # all_reduce over a device backend will fail, which points to a deeper
                # configuration issue rather than a problem here.
device = local_data["returns"].device
# These components only need to be summed. We can do a direct all_reduce.
components_to_sum = {k: v for k, v in local_data.items() if k.endswith("_comp")}
for tensor in components_to_sum.values():
if self.gather_group is not None:
dist.all_reduce(tensor.to(device), op=dist.ReduceOp.SUM, group=self.gather_group)
# Now all ranks have the global sums and can compute the final value.
N = local_data["returns"].numel()
total_N_tensor = torch.tensor([N], dtype=torch.int64, device=local_data["returns"].device)
if self.gather_group is not None:
dist.all_reduce(total_N_tensor.to(device), op=dist.ReduceOp.SUM, group=self.gather_group)
global_N = total_N_tensor.item()
if global_N > 0:
global_returns_sum = final_metrics["critic/returns/mean"] * global_N
global_returns_sq_sum = components_to_sum["returns_sq_sum_comp"].item()
global_error_sum = components_to_sum["error_sum_comp"].item()
global_error_sq_sum = components_to_sum["error_sq_sum_comp"].item()
mean_returns = global_returns_sum / global_N
var_returns = (global_returns_sq_sum / global_N) - (mean_returns**2)
mean_error = global_error_sum / global_N
var_error = (global_error_sq_sum / global_N) - (mean_error**2)
final_metrics["critic/vf_explained_var"] = 1.0 - var_error / (var_returns + 1e-8)
else:
final_metrics["critic/vf_explained_var"] = 0.0
# --- 5. Add timing and other rank-0-only metrics ---
# Only rank 0 needs to compute these for logging.
if self.rank == 0:
batch["global_token_num"] = NonTensorData([final_metrics.get("perf/total_num_tokens", 0)])
final_metrics.update(compute_throughout_metrics(batch, timing_raw, dist.get_world_size()))
final_metrics["perf/process_cpu_mem_used_gb"] = psutil.Process(os.getpid()).memory_info().rss / (1024**3)
timing_metrics = compute_timing_metrics(batch, timing_raw)
for key, value in timing_metrics.items():
if key.startswith("timing_s/"):
final_metrics[key.replace("timing_s/", "perf/delta_time/")] = value
# Calculate rollout and actor log probs difference statistics
if "rollout_log_probs" in batch and "old_log_probs" in batch:
rollout_probs = torch.exp(batch["rollout_log_probs"])
actor_probs = torch.exp(batch["old_log_probs"])
rollout_probs_diff = torch.masked_select(
                    torch.abs(rollout_probs.cpu() - actor_probs.cpu()),
batch["response_mask"].bool().cpu()
)
if rollout_probs_diff.numel() > 0:
final_metrics.update({
"training/rollout_probs_diff_max": torch.max(rollout_probs_diff).item(),
"training/rollout_probs_diff_mean": torch.mean(rollout_probs_diff).item(),
"training/rollout_probs_diff_std": torch.std(rollout_probs_diff).item()
})
# All ranks return the final metrics. Ranks other than 0 can use them if needed,
# or just ignore them. This is cleaner than returning an empty dict.
return final_metrics
def collect_multi_agent_final_metrics(self, batch: TensorDict, ordered_metrics: list, timing_raw: dict) -> list:
"""
Collects final metrics for multi-agent mode by iterating through agent rollout nodes.
Args:
batch: Final batch data (TensorDict) with multi-agent prefixes
ordered_metrics: List of (key, value) tuples to extend
timing_raw: Dictionary of raw timing measurements
Returns:
Extended list of ordered metrics
"""
node_queue = self.taskgraph.get_entry_nodes()
visited_nodes = set()
while node_queue:
cur_node = node_queue.pop(0)
            if cur_node.node_id in visited_nodes:
                continue
            visited_nodes.add(cur_node.node_id)
if cur_node.node_role != NodeRole.ROLLOUT:
break
batch = remove_prefix_from_dataproto(batch, cur_node)
final_metrics = self.collect_final_metrics(batch, timing_raw)
final_metrics = add_prefix_to_metrics(final_metrics, cur_node)
if final_metrics:
ordered_metrics.extend(sorted(final_metrics.items()))
if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):
for n in next_nodes:
if n.node_id not in visited_nodes:
node_queue.append(n)
batch = add_prefix_to_dataproto(batch, cur_node)
return ordered_metrics
================================================
FILE: siirl/dag_worker/validator.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import asyncio
import torch
import numpy as np
import torch.distributed as dist
from ray.actor import ActorHandle
from collections import defaultdict
from typing import Dict, List, Callable, Any, Optional, Tuple
from loguru import logger
from tensordict import TensorDict
# Handle different tensordict versions - NonTensorData location varies
try:
from tensordict import NonTensorData
except ImportError:
from tensordict.tensorclass import NonTensorData
from torch.distributed import ProcessGroup
from siirl.data_coordinator import preprocess_dataloader
from siirl.data_coordinator.dataloader import DataLoaderNode
from siirl.execution.dag.node import NodeRole, Node
from siirl.execution.scheduler.reward import create_reward_manager
from siirl.dag_worker.data_structures import ValidationPayload, ValidationResult
from siirl.dag_worker.dag_utils import dump_validation_generations, timer
from siirl.utils.metrics.metric_utils import aggregate_validation_metrics
from siirl.params import SiiRLArguments
class Validator:
"""
Handles the complete validation workflow for distributed RL training.
This class orchestrates the validation process including:
- Batch preparation and generation
- Reward scoring and result packaging
- Metrics aggregation across all ranks
- Performance logging
The validator operates in a distributed manner, coordinating across multiple
ranks and aggregating results on rank 0.
"""
def __init__(
self,
config: SiiRLArguments,
dataloader: DataLoaderNode,
validate_tokenizer: Any,
multi_agent_group: Dict[int, Dict[NodeRole, Any]],
rollout_mode: str,
async_rollout_manager: Optional[Any],
multi_agent_loop: Optional[Any],
multi_agent: bool,
rank: int,
world_size: int,
gather_group: ProcessGroup,
first_rollout_node: Node,
get_node_dp_info_fn: Callable,
enable_perf: bool = False,
metric_worker: ActorHandle = None
):
"""
Initialize the Validator with explicit dependencies.
Args:
config: Training configuration
dataloader: Data loading utilities
            validate_tokenizer: Tokenizer for decoding sequences
multi_agent_group: Worker groups for generation (indexed by agent_group -> role)
rollout_mode: Generation mode ('sync' or 'async')
async_rollout_manager: Manager for async rollout (None if not using async)
multi_agent_loop: Manager for multi-agent generation (None if not multi-agent)
multi_agent: Whether in multi-agent mode
rank: Current process rank
world_size: Total number of processes
gather_group: Process group for distributed gathering
first_rollout_node: First rollout node for getting DP/TP/PP info
get_node_dp_info_fn: Function to get node parallelism info
            enable_perf: Whether to enable performance profiling
            metric_worker: Ray actor handle used to submit aggregated validation metrics
"""
self.config = config
self.dataloader = dataloader
self.validate_tokenizer = validate_tokenizer
self.multi_agent_group = multi_agent_group
self.rollout_mode = rollout_mode
self.async_rollout_manager = async_rollout_manager
self.multi_agent_loop = multi_agent_loop
self.multi_agent = multi_agent
self.rank = rank
self.world_size = world_size
self.gather_group = gather_group
self.first_rollout_node = first_rollout_node
self.get_node_dp_info_fn = get_node_dp_info_fn
self.enable_perf = enable_perf
self.val_reward_fn = create_reward_manager(
self.config,
self.validate_tokenizer,
num_examine=1,
max_resp_len=self.config.data.max_response_length,
overlong_buffer_cfg=self.config.reward_model.overlong_buffer,
)
# Validation timing tracking
self.val_timedict = defaultdict(float)
self.metric_worker = metric_worker
def validate(self, global_step: int) -> Dict[str, float]:
"""Performs validation based on dataset type."""
        # Use the existing dataset_type parameter from the data config (defaults to "llm").
dataset_type = getattr(self.config.data, "dataset_type", "llm")
if dataset_type == "embodied":
return self._validate_embodied(global_step)
else:
return self._validate_text_generation(global_step)
def _validate_embodied(self, global_step) -> Dict[str, float]:
"""
Performs embodied validation by running interactive episodes.
This is the main entry point that orchestrates the entire validation flow:
1. Initialize timers and check prerequisites
2. Iterate through validation batches (each rank processes a shard)
3. Generate embodied episodes via rollout worker
4. Score results using val_reward_fn
5. Gather payloads from all ranks to rank 0
6. Aggregate and return final metrics
Returns:
Dict[str, float]: Validation metrics (only on rank 0, empty dict on other ranks)
"""
# 1. Initialize timers
self.timers = defaultdict(float)
if self.rank == 0:
logger.info("=" * 60)
logger.info(f"Starting Embodied Validation @ Global Step {global_step}...")
logger.info("=" * 60)
self.timers["overall_start_time"] = time.perf_counter()
# 2. Check if num_val_batches > 0 to avoid unnecessary loops
if self.dataloader.num_val_batches <= 0:
if self.rank == 0:
logger.warning("num_val_batches is 0. Skipping embodied validation.")
return {}
# 3. Collect payloads from all batches
all_payloads = []
for i in range(self.dataloader.num_val_batches):
if self.rank == 0:
logger.debug(f"Processing embodied validation batch {i + 1}/{self.dataloader.num_val_batches}")
# 3.1 Prepare and generate
with timer(self.enable_perf, "prep_and_generate", self.val_timedict):
batch_proto = self._prepare_embodied_validation_batch()
generated_proto = self._generate_for_embodied_validation(batch_proto, global_step)
dist.barrier(self.gather_group)
# 3.2 Score
with timer(self.enable_perf, "score", self.val_timedict):
batch_payloads = self._score_embodied_results(generated_proto)
all_payloads.extend(batch_payloads)
# 4. Gather payloads to rank 0 (only TP/PP master ranks prepare payload)
dp_size, _, tp_rank, _, pp_rank, _ = self.get_node_dp_info_fn(self.first_rollout_node)
with timer(self.enable_perf, "gather_payloads", self.val_timedict):
if tp_rank == 0 and pp_rank == 0:
self.metric_worker.submit_metric(self._aggregate_and_log_embodied_metrics(all_payloads, global_step), dp_size)
# with timer(self.enable_perf, "gather_payloads", self.val_timedict):
# payloads_for_metrics = []
# if tp_rank == 0 and pp_rank == 0:
# payloads_for_metrics = all_payloads
# gathered_payloads_on_rank0 = [None] * self.world_size if self._rank == 0 else None
# dist.gather_object(payloads_for_metrics, gathered_payloads_on_rank0, dst=0, group=self._gather_group)
# # 5. Rank 0 aggregates and logs metrics
# if self.rank == 0:
# flat_payload_list = [p for sublist in gathered_payloads_on_rank0 if sublist for p in sublist]
# final_metrics = self._aggregate_and_log_embodied_metrics(flat_payload_list)
dist.barrier(self.gather_group)
        return {}
def _validate_text_generation(self, global_step: int) -> Dict[str, float]:
"""
Executes the complete validation workflow.
This is the main entry point for validation. It:
1. Prepares validation batches from the dataloader
2. Generates sequences using the rollout model
3. Scores the generated sequences
4. Aggregates metrics across all ranks
5. Logs performance breakdown (on rank 0)
Args:
global_step: Current training step (for logging and checkpointing)
Returns:
Dict[str, float]: Validation metrics (only on rank 0, empty dict on other ranks)
"""
self.val_timedict = defaultdict(float)
if self.rank == 0:
logger.info("=" * 60)
logger.info(f"Starting Validation @ Global Step {global_step}...")
logger.info("=" * 60)
self.val_timedict["overall_start_time"] = time.perf_counter()
all_scored_results: List[ValidationResult] = []
# Check if num_val_batches > 0 to avoid unnecessary loops.
if self.dataloader.num_val_batches <= 0:
if self.rank == 0:
logger.warning("num_val_batches is 0. Skipping validation.")
return {}
sample_turns = []
for i in range(self.dataloader.num_val_batches):
if self.rank == 0:
logger.debug(f"Processing validation batch {i + 1}/{self.dataloader.num_val_batches}")
with timer(self.enable_perf, "prep_and_generate", self.val_timedict):
test_batch = self.dataloader.run(is_validation_step=True)
val_batch = preprocess_dataloader(test_batch, self.config.actor_rollout_ref.rollout.val_kwargs.n)
generated_proto = self._generate_for_validation(val_batch)
dist.barrier(self.gather_group)
with timer(self.enable_perf, "score_and_package", self.val_timedict):
scored_results = self._score_and_package_results(generated_proto)
all_scored_results.extend(scored_results)
dump_validation_generations(self.config, global_step, self.rank, all_scored_results)
dist.barrier(self.gather_group)
dp_size, _, tp_rank, _, pp_rank, _ = self.get_node_dp_info_fn(self.first_rollout_node)
# Gather all payloads to rank 0
with timer(self.enable_perf, "gather_payloads", self.val_timedict):
if tp_rank == 0 and pp_rank == 0:
# Only the master rank of the TP group (tp_rank=0) and first PP stage (pp_rank=0) prepares the payload.
payloads_for_metrics = [
ValidationPayload(r.input_text, r.score, r.data_source, r.extra_rewards) for r in all_scored_results
]
self.metric_worker.submit_metric(self._aggregate_and_log_validation_metrics(payloads_for_metrics), dp_size)
dist.barrier(self.gather_group)
        return {}
def _generate_for_validation(self, batch: TensorDict) -> TensorDict:
"""
Generates sequences using the rollout worker for a validation batch.
Supports three generation modes:
- Sync mode: Direct generation via rollout worker
- Async mode: Generation via async rollout manager
- Multi-agent mode: Generation via multi-agent loop
Args:
            batch: Input batch containing prompts
Returns:
TensorDict: Batch with generated sequences added
"""
rollout_worker = self.multi_agent_group[0][NodeRole.ROLLOUT]
val_kwargs = self.config.actor_rollout_ref.rollout.val_kwargs
prompt_texts = self.validate_tokenizer.batch_decode(
batch["input_ids"], skip_special_tokens=True
)
batch["prompt_texts"] = prompt_texts
batch["eos_token_id"] = NonTensorData(self.validate_tokenizer.eos_token_id)
batch["pad_token_id"] = NonTensorData(self.validate_tokenizer.eos_token_id)
batch["recompute_log_prob"] = NonTensorData(self.validate_tokenizer.eos_token_id)
batch["validate"] = NonTensorData(True)
batch["do_sample"] = NonTensorData(val_kwargs.do_sample)
output = None
if self.multi_agent is False:
if self.rollout_mode == 'sync':
output = rollout_worker.generate_sequences(batch)
elif self.async_rollout_manager:
loop = asyncio.get_event_loop()
output = loop.run_until_complete(self.async_rollout_manager.generate_sequences(batch))
else:
output = self.multi_agent_loop.generate_sequence(batch)
if output is not None:
return output
return batch
def _score_and_package_results(self, generated_proto: TensorDict) -> List[ValidationResult]:
"""
Scores generated sequences and packages them into ValidationResult objects.
This method:
1. Computes rewards for generated sequences (or uses pre-computed rewards)
2. Decodes input prompts and output responses
3. Packages everything into ValidationResult objects
4. Filters out padded duplicates for trailing ranks
Args:
generated_proto: Batch containing generated sequences
Returns:
List[ValidationResult]: Scored and packaged validation results
Dict: extra_rewards
"""
if self.rollout_mode == 'async' and self.async_rollout_manager is None:
return []
if self.multi_agent and 'responses' not in generated_proto:
return []
if "token_level_rewards" in generated_proto:
reward_result = {"reward_tensor": generated_proto["token_level_rewards"],
"reward_extra_info": {}}
else:
reward_result = self.val_reward_fn(generated_proto, return_dict=True)
scores = reward_result["reward_tensor"].sum(-1).cpu()
input_texts = generated_proto["prompt_texts"] if "prompt_texts" in generated_proto else None
if input_texts is None:
            logger.error(
                "FATAL: `prompt_texts` not found in the generated batch. "
                "The prompt data was lost along the way. Falling back to decoding the full sequence, "
                "but be aware the resulting `input_text` will be INCORRECT (it will contain prompt + response)."
            )
# Fallback to prevent a crash, but the output is known to be wrong.
input_texts = self.validate_tokenizer.batch_decode(
generated_proto["input_ids"], skip_special_tokens=True
)
output_texts = self.validate_tokenizer.batch_decode(generated_proto["responses"], skip_special_tokens=True)
data_sources = generated_proto["data_source"] if "data_source" in generated_proto else ["unknown"] * len(scores)
extra_info = generated_proto["extra_info"] if "data_source" in generated_proto else [None] * len(scores)
packaged_results = []
for i in range(len(scores)):
if self.dataloader.is_val_trailing_rank and isinstance(extra_info[i], dict) and extra_info[i].get("padded_duplicate", None):
logger.debug(f"Rank {self.rank} skip append padded duplicate item {i}: score={scores[i].item()}")
continue
extra_rewards = {k: v[i] for k, v in reward_result.get("reward_extra_info", {}).items()}
packaged_results.append(ValidationResult(input_texts[i], output_texts[i], scores[i].item(), data_sources[i], reward_result["reward_tensor"][i], extra_rewards))
return packaged_results
def _aggregate_and_log_validation_metrics(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:
"""
Aggregates all validation results and logs performance (rank 0 only).
This method:
1. Calls _aggregate_validation_results to compute final metrics
2. Logs a detailed performance breakdown of the validation process
3. Reports total validation time
Args:
all_payloads: All validation payloads gathered from all ranks
Returns:
Dict[str, float]: Final aggregated validation metrics
"""
if not all_payloads:
logger.warning("Validation finished with no results gathered on Rank 0 to aggregate.")
return {}
with timer(self.enable_perf, "final_aggregation", self.val_timedict):
final_metrics = self._aggregate_validation_results(all_payloads)
# Log performance breakdown
total_time = time.perf_counter() - self.val_timedict.pop("overall_start_time", time.perf_counter())
if self.rank == 0:
logger.info(f"Rank 0: Aggregating {len(all_payloads)} validation results...")
logger.info("--- Validation Performance Breakdown (Rank 0) ---")
for name, duration in self.val_timedict.items():
logger.info(f" Total {name.replace('_', ' ').title():<25}: {duration:.4f}s")
known_time = sum(self.val_timedict.values())
logger.info(f" {'Other/Overhead':<25}: {max(0, total_time - known_time):.4f}s")
logger.info(f" {'TOTAL VALIDATION TIME':<25}: {total_time:.4f}s")
logger.info("=" * 51)
return final_metrics
def _aggregate_validation_results(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:
"""
Computes the final metric dictionary from all gathered validation payloads.
This method processes validation results to compute:
- Mean/majority/best metrics for different data sources
- Pass@N accuracy metrics
- Per-data-source test scores
Args:
all_payloads: All validation payloads from all ranks
Returns:
Dict[str, float]: Final validation metrics organized by data source and metric type
"""
data_sources = [p.data_source for p in all_payloads]
sample_inputs = [p.input_text for p in all_payloads]
infos_dict = defaultdict(list)
for p in all_payloads:
infos_dict["reward"].append(p.score)
for key, value in p.extra_rewards.items():
infos_dict[key].append(value)
data_src2var2metric2val = aggregate_validation_metrics(data_sources=data_sources, sample_inputs=sample_inputs, infos_dict=infos_dict)
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
core_var = "acc" if "acc" in var2metric2val else "reward"
for var_name, metric2val in var2metric2val.items():
if not metric2val:
continue
# Robustly parse '@N' to prevent crashes from malformed metric names.
n_max_values = []
for name in metric2val.keys():
if "@" in name and "/mean" in name:
try:
n_val = int(name.split("@")[-1].split("/")[0])
n_max_values.append(n_val)
except (ValueError, IndexError):
continue # Ignore malformed metric names
n_max = max(n_max_values) if n_max_values else 1
for metric_name, metric_val in metric2val.items():
is_core_metric = (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name)
metric_sec = "val-core" if is_core_metric else "val-aux"
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
metric_dict[pfx] = metric_val
# Re-calculate test_score per data source
data_source_rewards = defaultdict(list)
for p in all_payloads:
data_source_rewards[p.data_source].append(p.score)
for source, rewards in data_source_rewards.items():
if rewards:
metric_dict[f"val/test_score/{source}"] = np.mean(rewards)
return metric_dict
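# Illustrative sketch (metric names assumed) of how the '@N' parsing above
# selects n_max and how keys end up sectioned. For names like "mean@8/mean"
# and "maj@8/mean", the parser extracts 8 from each, so n_max = 8 and only
# metrics carrying "@8" can be promoted to the "val-core" section:
#
#   name = "maj@8/mean"
#   n_val = int(name.split("@")[-1].split("/")[0])   # -> 8
#   # resulting key (data source assumed): "val-core/gsm8k/acc/maj@8/mean"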
def _prepare_embodied_validation_batch(self) -> TensorDict:
"""
Fetches and prepares a single embodied validation batch.
Unlike text-generation validation, embodied validation does NOT repeat batches
because running embodied episodes is expensive.
Returns:
TensorDict: The validation batch
"""
test_batch = self.dataloader.run(is_validation_step=True)
test_batch_proto = preprocess_dataloader(test_batch)
# Embodied validation batches are not repeated: running episodes is expensive
return test_batch_proto
def _generate_for_embodied_validation(self, batch: TensorDict, global_step: int) -> TensorDict:
"""
Generates embodied episodes using the rollout worker.
Sets up meta_info for validation mode (validate=True, do_sample=False) and
calls the appropriate generation method based on rollout configuration.
Args:
batch: The input batch containing task information
global_step: The current global training step, attached to the batch meta info
Returns:
TensorDict: The batch with generated episode data (actions, observations, rewards, etc.)
"""
rollout_worker = self.multi_agent_group[0][NodeRole.ROLLOUT]
# Set meta_info for embodied validation
batch["eos_token_id"] = NonTensorData(self.validate_tokenizer.eos_token_id)
batch["pad_token_id"] = NonTensorData(self.validate_tokenizer.pad_token_id)
batch["recompute_log_prob"] = NonTensorData(False)
batch["validate"] = NonTensorData(True)
batch["do_sample"] = NonTensorData(False)
batch["global_steps"] = NonTensorData(global_step)
logger.info(
f"[Embodied Validation] Batch variables: "
f"eos_token_id={batch['eos_token_id']}, "
f"pad_token_id={batch['pad_token_id']}, "
f"recompute_log_prob={batch['recompute_log_prob']}, "
f"validate={batch['validate']}, "
f"do_sample={batch['do_sample']}, "
f"global_steps={batch['global_steps']}"
)
# Embodied validation always uses the synchronous rollout path
output = rollout_worker.generate_sequences(batch)
# Fall back to the original batch if the worker returned nothing
if output is not None:
return output
return batch
def _score_embodied_results(self, generated_proto: TensorDict) -> List[ValidationPayload]:
"""
Scores generated embodied episodes using val_reward_fn and packages lightweight payloads.
Unlike text-generation, embodied validation:
- Uses val_reward_fn.verify() instead of val_reward_fn()
- Returns (verifier_score, reward_metrics, format_metrics, reward_format_metrics)
- Doesn't need to decode text prompts/responses
Args:
generated_proto: The batch with generated episode data
Returns:
List[ValidationPayload]: Lightweight payloads for gathering
"""
if self.val_reward_fn:
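# 1. Verify generated episodes with the validation reward function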
verifier_score, reward_metrics, format_metrics, reward_format_metrics = self.val_reward_fn.verify(generated_proto)
reward_tensor = torch.tensor(verifier_score, dtype=torch.float32).unsqueeze(-1)
# 2. Store batch-level metrics (prefixes are added during aggregation)
batch_metrics = {
'reward_metrics': reward_metrics,
'format_metrics': format_metrics,
'reward_format_metrics': reward_format_metrics,
}
# 3. Get data sources (task suite name)
task_suite_name = getattr(self.config.actor_rollout_ref.embodied.env, 'env_name', 'unknown_task')
data_sources = generated_proto.get("data_source", [task_suite_name] * reward_tensor.shape[0])
# 4. Get input identifiers (for debugging/logging)
# For embodied tasks, we use task_file_name if available
if "task_file_name" in generated_proto:
task_file_names_bytes = generated_proto["task_file_name"].cpu().numpy()
input_texts = []
for tfn_bytes in task_file_names_bytes:
# Decode bytes to string
tfn_str = bytes(tfn_bytes).decode('utf-8').rstrip('\x00')
input_texts.append(f"Task: {tfn_str}")
else:
# Fallback: use generic episode identifiers
input_texts = [f"Episode_{i}" for i in range(reward_tensor.shape[0])]
# 5. Compute scores
scores = reward_tensor.sum(-1).cpu()
# 6. Package payloads (lightweight, no full reward tensor)
# Only the first sample in each batch carries the batch_metrics to avoid duplication
packaged_payloads = []
for i in range(len(scores)):
payload = ValidationPayload(
input_text=input_texts[i],
score=scores[i].item(),
data_source=data_sources[i],
extra_rewards=batch_metrics if i == 0 else {} # Only first sample carries batch metrics
)
packaged_payloads.append(payload)
return packaged_payloads
# No reward function configured for embodied validation: nothing to score
return []
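# Illustrative payload layout (values assumed): for an 8-episode batch,
# payloads[0].extra_rewards carries the batch-level dicts
# {'reward_metrics': ..., 'format_metrics': ..., 'reward_format_metrics': ...}
# while payloads[1:].extra_rewards are {}, so each batch's metrics are
# counted exactly once during aggregation.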
def _aggregate_and_log_embodied_metrics(self, all_payloads: List[ValidationPayload], global_step: int) -> Dict[str, float]:
"""
On Rank 0, aggregates all embodied validation results and logs performance.
This function runs only on rank 0 after gathering payloads from all ranks.
Args:
all_payloads: All ValidationPayload objects gathered from all ranks
global_step: The current global training step, used for logging
Returns:
Dict[str, float]: Final validation metrics
"""
if not all_payloads:
logger.warning("Embodied validation finished with no results gathered on Rank 0.")
return {}
logger.info(f"Rank 0: Aggregating {len(all_payloads)} embodied validation results...")
# 1. Aggregate results
with timer(self.enable_perf, "final_aggregation", self.val_timedict):
final_metrics = self._aggregate_embodied_results(all_payloads)
# 2. Log performance breakdown
if self.rank == 0:
total_time = time.perf_counter() - self.val_timedict.pop("overall_start_time", time.perf_counter())
logger.info("=" * 60)
logger.info("--- Embodied Validation Performance Breakdown (Rank 0) ---")
for name, duration in self.val_timedict.items():
logger.info(f" Total {name.replace('_', ' ').title():<30}: {duration:.4f}s")
known_time = sum(self.val_timedict.values())
logger.info(f" {'Other/Overhead':<30}: {max(0, total_time - known_time):.4f}s")
logger.info(f" {'TOTAL EMBODIED VALIDATION TIME':<30}: {total_time:.4f}s")
logger.info("=" * 60)
# 3. Log final metrics
logger.info(f"Embodied Validation Metrics (Global Step {global_step}):")
for metric_name, metric_value in sorted(final_metrics.items()):
logger.info(f" {metric_name}: {metric_value:.4f}")
logger.info("=" * 60)
return final_metrics
def _aggregate_embodied_results(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:
"""
Computes the final metric dictionary from all gathered embodied validation payloads.
Aggregation strategy:
1. Group rewards by data_source
2. Compute mean reward per data_source and overall
3. Collect and average batch-level metrics from all batches
Args:
all_payloads: All ValidationPayload objects from all ranks
Returns:
Dict[str, float]: Final metrics with structure:
- embodied/test_score/{data_source}: Mean reward per data source
- embodied/test_score/all: Overall mean reward
- embodied/reward/{metric_name}: Averaged reward metrics across all batches
- embodied/format/{metric_name}: Averaged format metrics across all batches
- embodied/reward_format/{metric_name}: Averaged combined metrics across all batches
"""
# 1. Group rewards by data_source
data_source_rewards = defaultdict(list)
for payload in all_payloads:
data_source_rewards[payload.data_source].append(payload.score)
# 2. Compute per-data-source metrics
metric_dict = {}
for data_source, rewards in data_source_rewards.items():
metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)
# 3. Compute overall mean reward
all_rewards = [p.score for p in all_payloads]
metric_dict['val/test_score/all'] = np.mean(all_rewards)
# 4. Collect and average batch-level metrics from all batches
batch_metrics_list = []
for payload in all_payloads:
if payload.extra_rewards and 'reward_metrics' in payload.extra_rewards:
batch_metrics_list.append(payload.extra_rewards)
if batch_metrics_list:
num_batches = len(batch_metrics_list)
logger.info(f"[_aggregate_embodied_results] Collected metrics from {num_batches} batches")
# 4.1 Collect reward/format/combined metrics from every batch in one pass
aggregated = {
'reward_metrics': defaultdict(list),
'format_metrics': defaultdict(list),
'reward_format_metrics': defaultdict(list),
}
for batch_metrics in batch_metrics_list:
for category, collected in aggregated.items():
for key, value in batch_metrics[category].items():
collected[key].append(value)
# 5. Compute average values and add to metric_dict
category_prefix = {
'reward_metrics': 'embodied/reward',
'format_metrics': 'embodied/format',
'reward_format_metrics': 'embodied/reward_format',
}
for category, collected in aggregated.items():
for key, values in collected.items():
metric_dict[f'{category_prefix[category]}/{key}'] = np.mean(values)
return metric_dict
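# Illustrative output (task suite and values assumed): for payloads from the
# "libero_spatial" suite the returned dict might look like
#   {
#       'val/test_score/libero_spatial': 0.62,
#       'val/test_score/all': 0.62,
#       'embodied/reward/all': 0.62,
#       'embodied/format/all': 0.98,
#       'embodied/reward_format/all': 0.61,
#   }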
================================================
FILE: siirl/data_coordinator/__init__.py
================================================
# Copyright (c) 2025, Shanghai Innovation Institute. All rights reserved.
from .protocol import *
from .data_buffer import *
from .sample import *
================================================
FILE: siirl/data_coordinator/data_buffer.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Dict, List, Optional, Tuple, Callable, Any
import heapq
import random
import ray
import loguru
import time
from collections import deque, defaultdict
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from siirl.data_coordinator.sample import SampleInfo
from siirl.utils.model_utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions
@ray.remote
class DataCoordinator:
"""
A globally unique central Actor responsible for coordinating data producers (RolloutWorkers)
and consumers (Trainers). It does not store the actual sample data, only the sample
metadata (SampleInfo) and object references (ObjectRef). This allows it to implement
complex global sampling strategies at a very low cost.
"""
def __init__(self, nnodes: int, ppo_mini_batch_size: int, world_size: int):
self.nnodes = nnodes
self.ppo_mini_batch_size = ppo_mini_batch_size
self.world_size = world_size
# Use a deque to store tuples of metadata and references for efficient FIFO operations
self._sample_queue: deque[Tuple[SampleInfo, ray.ObjectRef]] = deque()
self._put_counter = 0 # Used for round-robin buffer selection
self.lock = asyncio.Lock()
loguru.logger.info("Global DataCoordinator initialized.")
# Cache for multi-rank access: [[rank0_data], [rank1_data], ...]
self._cache: List[List[ray.ObjectRef]] = []
self._cache_key: Optional[str] = None # Track which key the cache belongs to
# === Statistics tracking for dynamic sampling scenarios ===
self._stats_batches_received = 0 # Number of put_batch calls
self._stats_samples_received = 0 # Total samples received since last dispatch
self._stats_accumulation_start = None # Time when accumulation started
self._stats_last_progress_pct = 0 # Last logged progress percentage
async def put(self, sample_info: SampleInfo, sample_ref: Any, caller_node_id: Optional[str] = None):
"""
Called by a RolloutWorker to register a new sample reference and its metadata.
This method automatically routes the ObjectRef to a DataBuffer on its local
node to be held.
Args:
sample_info: Metadata about the sample
sample_ref: Ray ObjectRef or the actual sample data
caller_node_id: The node ID of the caller. If None, will try to get it from
the runtime context (but this won't work correctly for remote calls)
"""
# Due to Ray's small object optimization, an ObjectRef passed by the client
# might be automatically resolved to its actual value. Here, we ensure that
# we are always handling an ObjectRef.
if not isinstance(sample_ref, ray.ObjectRef):
sample_ref = ray.put(sample_ref)
# 1. Get the node ID of the caller
# Note: When called remotely, ray.get_runtime_context().get_node_id() returns
# the node ID of the DataCoordinator actor, not the caller. So we require the
# caller to pass their node_id explicitly.
if caller_node_id is None:
caller_node_id = ray.get_runtime_context().get_node_id()
loguru.logger.warning(
"DataCoordinator.put() called without caller_node_id. "
f"Using DataCoordinator's node_id {caller_node_id[:16]}... which may be incorrect."
)
# 2. Inject the node ID into SampleInfo for subsequent filtering
# Only inject if node_id has not been manually set, to facilitate testing.
if sample_info.node_id is None:
sample_info.node_id = caller_node_id
# 3. Register the metadata and reference to the global queue + update statistics
async with self.lock:
self._sample_queue.append((sample_info, sample_ref))
# Update statistics (single sample put)
self._stats_batches_received += 1
self._stats_samples_received += 1
if self._stats_accumulation_start is None:
self._stats_accumulation_start = time.time()
async def put_batch(self, sample_infos: List[SampleInfo], sample_refs: List[ray.ObjectRef], caller_node_id: Optional[str] = None):
"""
Called by a worker to register a batch of new sample references and their metadata.
This method routes the ObjectRefs to DataBuffers on their local nodes.
Args:
sample_infos: List of metadata for each sample
sample_refs: List of Ray ObjectRefs
caller_node_id: The node ID of the caller. If None, will try to get it from
the runtime context (but this won't work correctly for remote calls)
"""
if not sample_refs:
return
# Get the node ID of the caller
# Note: When called remotely, ray.get_runtime_context().get_node_id() returns
# the node ID of the DataCoordinator actor, not the caller. So we require the
# caller to pass their node_id explicitly.
if caller_node_id is None:
caller_node_id = ray.get_runtime_context().get_node_id()
loguru.logger.warning(
"DataCoordinator.put_batch() called without caller_node_id. "
f"Using DataCoordinator's node_id {caller_node_id[:16]}... which may be incorrect."
)
for i in range(len(sample_infos)):
if sample_infos[i].node_id is None:
sample_infos[i].node_id = caller_node_id
async with self.lock:
self._sample_queue.extend(zip(sample_infos, sample_refs))
# Update statistics
self._stats_batches_received += 1
self._stats_samples_received += len(sample_refs)
if self._stats_accumulation_start is None:
self._stats_accumulation_start = time.time()
async def get_batch(
self,
batch_size: int,
dp_rank: int,
filter_plugin: Optional[Callable[[SampleInfo], bool]] = None,
balance_partitions: Optional[int] = None,
cache_key: Optional[str] = None
) -> List[ray.ObjectRef]:
"""Called by a Trainer to get a batch of sample ObjectRefs.
Supports caching for multi-rank access within the same node, filter plugins
for custom sampling logic, and length balancing across partitions.
Args:
batch_size: The requested batch size per partition.
dp_rank: The data parallel rank requesting data.
filter_plugin: Optional filter function for custom sampling logic.
balance_partitions: Number of partitions for length balancing (typically dp_size).
cache_key: Key to identify the cache (e.g., node name like 'compute_reward').
Different keys invalidate the cache to prevent data mixing.
Returns:
A list of sample ObjectRefs for the specified dp_rank.
"""
async with self.lock:
global_batch_size = batch_size * (balance_partitions or 1)
# === Phase 1: Check cache validity ===
if self._cache:
if cache_key is not None and self._cache_key == cache_key:
# Same key: return cached data (allows multiple reads by tp/pp ranks)
if dp_rank < len(self._cache):
return self._cache[dp_rank]
return []
# Different key: invalidate old cache
self._cache = []
self._cache_key = None
# === Phase 2: Fetch data from queue ===
if filter_plugin:
# With filter: O(N) scan
if isinstance(filter_plugin, list):
batch_items = [item for item in self._sample_queue
if all(f(item[0]) for f in filter_plugin)]
else:
batch_items = [item for item in self._sample_queue
if filter_plugin(item[0])]
if len(batch_items) < global_batch_size:
self._log_accumulation_progress(len(batch_items), global_batch_size)
return []
# Take only what we need and remove from queue
batch_items = batch_items[:global_batch_size]
refs_to_remove = {item[1] for item in batch_items}
self._sample_queue = deque(
item for item in self._sample_queue if item[1] not in refs_to_remove
)
else:
# No filter: efficient FIFO
if len(self._sample_queue) < global_batch_size:
self._log_accumulation_progress(len(self._sample_queue), global_batch_size)
return []
batch_items = [self._sample_queue.popleft() for _ in range(global_batch_size)]
# === Phase 3: Apply length balancing ===
if balance_partitions and balance_partitions > 1:
batch_refs = self._apply_length_balancing(batch_items, balance_partitions)
else:
batch_refs = [item[1] for item in batch_items]
# === Phase 4: Build cache (unified nested list structure) ===
self._cache = [
batch_refs[rank * batch_size:(rank + 1) * batch_size]
for rank in range(balance_partitions or 1)
]
self._cache_key = cache_key
self._log_dispatch_stats(global_batch_size)
return self._cache[dp_rank]
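# Illustrative call pattern (sizes assumed): with balance_partitions=4 and
# batch_size=8, the first caller that finds 32 queued samples builds the
# 4-way nested cache; later callers passing the same cache_key (e.g. other
# tp/pp ranks of the same node) just read self._cache[dp_rank], until a
# call with a different cache_key invalidates the cache and refetches.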
def _log_accumulation_progress(self, current_samples: int, target_samples: int):
"""Log progress milestones at INFO level when reaching 25%, 50%, 75%."""
if target_samples <= 0:
return
current_pct = int(current_samples * 100 / target_samples)
# Log at 25%, 50%, 75% milestones (only once per milestone)
# Log the highest crossed milestone that hasn't been logged yet
milestones = [25, 50, 75]
highest_crossed = None
for milestone in milestones:
if current_pct >= milestone and self._stats_last_progress_pct < milestone:
highest_crossed = milestone
if highest_crossed is not None:
wait_time = time.time() - self._stats_accumulation_start if self._stats_accumulation_start else 0
loguru.logger.info(
f"[DataCoordinator] Accumulation {highest_crossed}%: {current_samples}/{target_samples} samples "
f"({self._stats_batches_received} batches, {wait_time:.1f}s)"
)
self._stats_last_progress_pct = highest_crossed
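# Example (illustrative): with _stats_last_progress_pct == 25 and
# current_pct == 80, milestones 50 and 75 were both crossed since the last
# log; the loop keeps the highest one (75), logs it once, and records it so
# subsequent calls below 100% stay silent.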
def _log_dispatch_stats(self, dispatched_samples: int):
"""Log statistics when dispatching a batch and reset counters."""
wait_time = time.time() - self._stats_accumulation_start if self._stats_accumulation_start else 0
avg_samples_per_batch = (
self._stats_samples_received / self._stats_batches_received
if self._stats_batches_received > 0 else 0
)
total_received = self._stats_samples_received
remaining_in_queue = total_received - dispatched_samples
loguru.logger.info(
f"[DataCoordinator DISPATCH] "
f"Accumulated: {total_received} samples from {self._stats_batches_received} batches | "
f"Dispatching: {dispatched_samples} samples | "
f"Remaining in queue: {remaining_in_queue} | "
f"Avg per batch: {avg_samples_per_batch:.1f} | "
f"Wait: {wait_time:.1f}s"
)
self._stats_batches_received = 0
self._stats_samples_received = 0
self._stats_accumulation_start = None
self._stats_last_progress_pct = 0
def _apply_length_balancing(
self,
batch_items: List[Tuple[SampleInfo, ray.ObjectRef]],
k_partitions: int,
keep_mini_batch = False
) -> List[ray.ObjectRef]:
"""Applies the length balancing algorithm to reorder samples.
Uses the LPT (Longest Processing Time) algorithm to reorder samples so that
if they are evenly distributed among k_partitions workers, the sum of
sample lengths for each worker is as balanced as possible.
Supports Group N: samples with the same uid will be assigned to the same partition,
ensuring correct group-relative advantage computation for GRPO and similar algorithms.
Args:
batch_items: A list of (SampleInfo, ObjectRef) tuples.
k_partitions: The number of partitions (typically the DP size).
keep_mini_batch: Whether to keep mini-batch structure during balancing.
Returns:
A reordered list of ObjectRefs.
"""
# ========== Step 1: Group samples by uid ==========
uid_to_indices = defaultdict(list)
for idx, (sample_info, _) in enumerate(batch_items):
uid = sample_info.uid if sample_info.uid is not None else str(idx)
uid_to_indices[uid].append(idx)
# Check if grouping is needed (max_group_size > 1 means we have Group N)
max_group_size = max(len(indices) for indices in uid_to_indices.values()) if uid_to_indices else 1
if max_group_size == 1:
# No grouping needed, use original single-sample balancing logic
return self._apply_length_balancing_single_sample(batch_items, k_partitions, keep_mini_batch)
# ========== Step 2: Calculate workload for each Group ==========
group_list = list(uid_to_indices.keys()) # All unique uids
group_workloads = []
for uid in group_list:
indices = uid_to_indices[uid]
# Group workload = sum of all samples' sum_tokens in the group
total_tokens = sum(batch_items[i][0].sum_tokens for i in indices)
group_workloads.append(total_tokens)
# ========== Step 3: Balance Groups across partitions ==========
workload_lst = calculate_workload(group_workloads)
# Check if number of groups is divisible by k_partitions
num_groups = len(group_list)
if num_groups < k_partitions:
loguru.logger.warning(
f"Number of groups ({num_groups}) is less than partitions ({k_partitions}). "
f"Some partitions will be empty. Falling back to single-sample balancing."
)
return self._apply_length_balancing_single_sample(batch_items, k_partitions, keep_mini_batch)
equal_size = num_groups % k_partitions == 0
if not equal_size:
loguru.logger.warning(
f"Number of groups ({num_groups}) is not divisible by partitions ({k_partitions}). "
f"Some partitions may have uneven group counts."
)
# Partition groups across workers
group_partitions = get_seqlen_balanced_partitions(workload_lst, k_partitions=k_partitions, equal_size=equal_size)
# ========== Step 4: Expand groups to samples, keeping group integrity ==========
reordered_refs = []
for partition_group_indices in group_partitions:
for group_idx in partition_group_indices:
uid = group_list[group_idx]
sample_indices = uid_to_indices[uid]
# Add all samples of the same group together, preserving original order within group
for sample_idx in sample_indices:
reordered_refs.append(batch_items[sample_idx][1])
loguru.logger.debug(
f"Applied GROUP-aware length balancing: "
f"{len(batch_items)} samples in {num_groups} groups (group_size={max_group_size}) "
f"reordered into {k_partitions} partitions"
)
return reordered_refs
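# Worked example (all values assumed): 6 samples with uids/token counts
#   uid "a": [120, 80]    -> group workload 200
#   uid "b": [300]        -> group workload 300
#   uid "c": [90, 60, 50] -> group workload 200
# With k_partitions=2, the groups are split roughly as {"b"} vs {"a", "c"}
# (300 vs 400 tokens), and every sample of a uid lands on the same
# partition, so group-relative advantages (e.g. GRPO) stay well-defined.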
def _apply_length_balancing_single_sample(
self,
batch_items: List[Tuple[SampleInfo, ray.ObjectRef]],
k_partitions: int,
keep_mini_batch=False,
) -> List[ray.ObjectRef]:
"""Original length balancing logic for single samples (no UID grouping).
This is used when there's no Group N (each uid has only one sample).
Args:
batch_items: A list of (SampleInfo, ObjectRef) tuples.
k_partitions: The number of partitions (typically the DP size).
keep_mini_batch: Whether to keep mini-batch structure during balancing.
Returns:
A reordered list of ObjectRefs.
"""
# Extract the length of each sample.
# Use sum_tokens as the length metric (includes prompt + response).
seqlen_list = [item[0].sum_tokens for item in batch_items]
# Use the karmarkar_karp balance
workload_lst = calculate_workload(seqlen_list)
# Decouple the DP balancing and mini-batching.
if keep_mini_batch:
minibatch_size = self.ppo_mini_batch_size
minibatch_num = len(workload_lst) // minibatch_size
global_partition_lst = [[] for _ in range(k_partitions)]
for i in range(minibatch_num):
rearrange_minibatch_lst = get_seqlen_balanced_partitions(
workload_lst[i * minibatch_size : (i + 1) * minibatch_size],
k_partitions=k_partitions,
equal_size=True,
)
for j, part in enumerate(rearrange_minibatch_lst):
global_partition_lst[j].extend([x + minibatch_size * i for x in part])
else:
global_partition_lst = get_seqlen_balanced_partitions(
workload_lst, k_partitions=k_partitions, equal_size=True
)
# Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
for idx, partition in enumerate(global_partition_lst):
partition.sort(key=lambda x: (workload_lst[x], x))
ordered_partition = partition[::2] + partition[1::2][::-1]
global_partition_lst[idx] = ordered_partition
# Reorder the samples based on the partitioning result.
# Concatenate the partitions in order: [all samples from partition_0, all from partition_1, ...]
reordered_refs = []
for partition in global_partition_lst:
for original_idx in partition:
reordered_refs.append(batch_items[original_idx][1])
loguru.logger.debug(
f"Applied length balancing: {len(batch_items)} samples reordered into {k_partitions} partitions"
)
return reordered_refs
async def get_all_by_filter(self, filter_plugin: Callable[[SampleInfo], bool]) -> List[ray.ObjectRef]:
"""
Gets ALL sample ObjectRefs that match the filter plugin, consuming them from the queue.
This is useful for pipeline-based data passing where a downstream stage needs the
entire output of an upstream stage.
"""
async with self.lock:
# 1. Find all items that match the filter.
items_to_return = [item for item in self._sample_queue if filter_plugin(item[0])]
if not items_to_return:
return []
# 2. Extract their ObjectRefs.
batch_refs = [item[1] for item in items_to_return]
# 3. Efficiently remove the selected items from the original queue.
refs_to_remove = set(batch_refs)
self._sample_queue = deque(item for item in self._sample_queue if item[1] not in refs_to_remove)
return batch_refs
async def get_valid_size(self) -> int:
"""Returns the number of samples in the current queue."""
async with self.lock:
return len(self._sample_queue)
async def peek_source_dp_size(self, filter_plugin: Callable[[SampleInfo], bool]) -> Optional[int]:
"""
Peek at the source_dp_size of matching samples without consuming them.
Args:
filter_plugin: Filter function to find matching samples
Returns:
The source_dp_size if found, None otherwise
"""
async with self.lock:
for sample_info, _ in self._sample_queue:
if filter_plugin(sample_info):
source_dp_size = sample_info.dict_info.get('source_dp_size')
if source_dp_size is not None:
return source_dp_size
return None
def reset_cache(self):
"""Reset the coordinator state for a new training step."""
loguru.logger.info("Resetting DataCoordinator cache")
self._sample_queue.clear()
self._cache = []
self._cache_key = None
def __repr__(self) -> str:
return f""
# ====================================================================
# Initialization Logic
# ====================================================================
def init_data_coordinator(num_buffers: int, ppo_mini_batch_size: int, world_size: int) -> ray.actor.ActorHandle:
"""
Initializes the data coordination system, which includes a global DataCoordinator
and multiple distributed DataBuffers. Returns a single, unified DataCoordinator
handle to the user.
Args:
num_buffers: The number of distributed DataBuffer instances to create,
usually equal to the number of nodes or total GPUs.
ppo_mini_batch_size: The PPO mini-batch size, used for mini-batch-aware
length balancing.
world_size: The total number of data-parallel consumers.
Returns:
The Actor handle for the DataCoordinator.
"""
if not ray.is_initialized():
raise RuntimeError("Ray must be initialized before calling init_data_coordinator.")
# 1. Create or get the globally unique DataCoordinator
# Use a global name to ensure the coordinator's uniqueness
coordinator_name = "global_data_coordinator"
try:
coordinator = ray.get_actor(coordinator_name)
loguru.logger.info(f"Connected to existing DataCoordinator actor '{coordinator_name}'.")
except ValueError:
loguru.logger.info(f"Creating new DataCoordinator actor with global name '{coordinator_name}'.")
coordinator = DataCoordinator.options(name=coordinator_name, lifetime="detached").remote(nnodes=num_buffers, ppo_mini_batch_size=ppo_mini_batch_size, world_size=world_size)
return coordinator
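# ----------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the library API surface
# beyond what is defined above). SampleInfo construction details are
# assumptions; see siirl/data_coordinator/sample.py for the real fields.
#
#   import ray
#   from siirl.data_coordinator.data_buffer import init_data_coordinator
#   from siirl.data_coordinator.sample import SampleInfo
#
#   ray.init()
#   coordinator = init_data_coordinator(num_buffers=1, ppo_mini_batch_size=4, world_size=1)
#   node_id = ray.get_runtime_context().get_node_id()
#   ray.get(coordinator.put.remote(SampleInfo(...), {"input_ids": [1, 2, 3]},
#                                  caller_node_id=node_id))
#   refs = ray.get(coordinator.get_batch.remote(batch_size=1, dp_rank=0,
#                                               balance_partitions=1, cache_key="demo"))
#   samples = ray.get(refs)
# ----------------------------------------------------------------------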
================================================
FILE: siirl/data_coordinator/dataloader/__init__.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_loader_node import DataLoaderNode
__all__ = ["DataLoaderNode"]
================================================
FILE: siirl/data_coordinator/dataloader/data_loader_node.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, Optional
import torch
from loguru import logger
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from siirl.execution.dag import Node, NodeRole, NodeStatus, NodeType
from siirl.models.loader import load_tokenizer
from siirl.params import SiiRLArguments
from siirl.data_coordinator.dataloader.partitioned_dataset import PartitionedRLHFDataset
class RepeatDataset(torch.utils.data.Dataset):
"""
A dataset wrapper that repeats the base dataset multiple times.
This class allows you to virtually extend the length of a given dataset by repeating its samples
a specified number of times. It is useful for scenarios where you want to train for more epochs
without reloading or reshuffling the data, or to balance datasets by oversampling.
Args:
base_dataset (torch.utils.data.Dataset): The original dataset to be repeated.
repeat_factor (int): The number of times to repeat the base dataset.
Attributes:
base_dataset (torch.utils.data.Dataset): The original dataset.
repeat_factor (int): The number of repetitions.
length (int): The total length of the repeated dataset.
Example:
>>> base_dataset = MyCustomDataset()
>>> repeated_dataset = RepeatDataset(base_dataset, repeat_factor=3)
>>> len(repeated_dataset) == 3 * len(base_dataset)
True
"""
def __init__(self, base_dataset, repeat_factor):
self.base_dataset = base_dataset
self.repeat_factor = repeat_factor
self.length = len(base_dataset) * repeat_factor
def __len__(self):
return self.length
def __getitem__(self, idx):
return self.base_dataset[idx % len(self.base_dataset)]
class DataLoaderNode(Node):
"""
Represents a data loader node in the DAG.
This version uses the PartitionedRLHFDataset for efficient, memory-safe
distributed data loading. Each rank only loads and processes its own data slice.
"""
def __init__(
self, node_id: str, global_config: SiiRLArguments, config: Optional[Dict[str, Any]] = None, retry_limit: int = 0
):
"""
Initialize a data loader node.
Args:
node_id (str): The unique identifier of the node.
global_config(SiiRLArguments): The arguments from config file.
config (Optional[Dict[str, Any]]): Specific configuration information for the node. Defaults to an empty dictionary.
retry_limit (int): The maximum number of retries when the node execution fails. Defaults to 0 (no retries).
"""
super().__init__(node_id, NodeType.DATA_LOAD, NodeRole.DEFAULT, config=config, retry_limit=retry_limit)
self.global_config = global_config
if "tokenizer" in self.config:
self.tokenizer = self.config["tokenizer"]
self.processor = self.config["processor"]
else:
# Load tokenizer and processor
tokenizer_module = load_tokenizer(model_args=global_config.actor_rollout_ref.model)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
# force load in main process for vision language model
self.num_loader_workers = (
0
if global_config.actor_rollout_ref.rollout.name == "sglang" or self.processor is not None
else config.get("num_loader_workers", 8)
)
# Get group world size, rank, parallel size from config.
# Group world size means the rollout pytorch distributed group total gpus.
# Group rank means the process index in distributed group.
# Group parallel size means the rollout total parallel size, e.g. tp_size * pp_size
self.group_world_size = config["group_world_size"]
self.group_rank = config["group_rank"]
self.group_parallel_size = config["group_parallel_size"]
if self.group_world_size % self.group_parallel_size != 0:
# Log an error or raise an exception if world_size is not divisible by group_parallel_size
error_msg = f"group_world_size ({self.group_world_size}) must be divisible by group_parallel_size ({self.group_parallel_size})."
logger.error(error_msg)
raise ValueError(error_msg)
# Calculate the world size and rank for rollout data parallelism, which is actually needed for data partitioning.
self.rollout_ddp_world_size = self.group_world_size // self.group_parallel_size
self.rollout_ddp_rank = self.group_rank // self.group_parallel_size
self._current_train_iter: Optional[Iterator] = None
self._current_val_iter: Optional[Iterator] = None
self._current_epoch: int = -1
self._create_dataloader()
self.num_train_batches = len(self.train_dataloader) if self.train_dataloader else 0
self.num_val_batches = len(self.val_dataloader) if self.val_dataloader else 0
logger.info(f"DataLoaderNode '{self.node_id}' initialized:")
logger.info(f" Group rank: {self.group_rank} / {self.group_world_size}")
logger.info(f" Rollout DDP rank: {self.rollout_ddp_rank} / {self.rollout_ddp_world_size}")
logger.info(f" Train batches per epoch for this rank: {self.num_train_batches}")
logger.info(f" Total training steps (approx): {self.total_training_steps}")
def _create_dataloader(self):
"""
Initializes and configures the training and validation DataLoaders for RLHF tasks.
When `auto_repeat` is enabled and the dataset is too small to form a batch, it is automatically
repeated until at least one batch can be formed.
This method performs the following steps:
1. Creates the training dataset (`PartitionedRLHFDataset`) with the provided configuration, tokenizer, processor, and distributed data parallel (DDP) settings.
2. Sets up the sampler for the training DataLoader:
- Uses a `RandomSampler` with a seeded generator if shuffling is enabled in the configuration.
- Uses a `SequentialSampler` otherwise.
3. Configures the tokenizer's padding side to "left" to ensure correct sequence alignment.
4. Creates the training DataLoader (`StatefulDataLoader`) with the specified batch size, number of workers, sampler, and collator.
5. Creates the validation dataset and DataLoader, using the full dataset as a single batch for evaluation.
6. Asserts that the training DataLoader contains at least one batch.
7. Calculates the total number of training steps based on the number of batches and epochs, or uses a user-specified value if provided.
8. Updates the total training steps in the optimizer configurations for both the actor and critic components.
"""
# Create the partitioned training dataset for this rank
self.train_dataset = PartitionedRLHFDataset(
config=self.global_config,
tokenizer=self.tokenizer,
processor=self.processor,
ddp_rank=self.rollout_ddp_rank,
ddp_world_size=self.rollout_ddp_world_size,
is_eval=False,
drop_last=self.config.get("train_drop_last", True),
)
# Calculate batch size per rank
train_batch_size = self.global_config.data.train_batch_size // self.rollout_ddp_world_size
self.train_batch_size = train_batch_size
# Auto-repeat logic: if dataset is too small, repeat it until at least one batch can be formed
auto_repeat = self.config.get("auto_repeat", False)
train_len = len(self.train_dataset)
if auto_repeat and train_len < train_batch_size:
repeat_factor = (train_batch_size + train_len - 1) // train_len
self.train_dataset = RepeatDataset(self.train_dataset, repeat_factor)
logger.warning(
f"Rank {self.rollout_ddp_rank}: Training dataset too small (size={train_len}), "
f"auto-repeating {repeat_factor} times to ensure at least one batch (batch_size={train_batch_size}). "
f"Now RepeatDataset size={len(self.train_dataset)}"
)
# Choose sampler: RandomSampler with seed if shuffle enabled, else SequentialSampler
if self.global_config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.global_config.trainer.seed)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
# Create the training dataloader with the specified batch size, workers, sampler, and collator
from siirl.data_coordinator.dataloader.partitioned_dataset import collate_fn as default_collate_fn
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=train_batch_size,
num_workers=self.num_loader_workers,
drop_last=True,
collate_fn=default_collate_fn,
sampler=sampler,
)
# Create the partitioned validation dataset for this rank
self.val_dataset = PartitionedRLHFDataset(
config=self.global_config,
tokenizer=self.tokenizer,
processor=self.processor,
ddp_rank=self.rollout_ddp_rank,
ddp_world_size=self.rollout_ddp_world_size,
is_eval=True,
drop_last=self.config.get("eval_drop_last", False),
)
# Create the validation dataloader, loading the entire validation set as one batch
val_batch_size = self.global_config.data.val_batch_size # Prefer config value if set
if val_batch_size is None:
val_batch_size = len(self.val_dataset)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=val_batch_size,
num_workers=self.num_loader_workers,
shuffle=False,
drop_last=False,
collate_fn=default_collate_fn,
)
# Assert that there is at least one batch for this rank
assert (
len(self.train_dataloader) >= 1
), f"Not enough data for current rank (rank id: {self.rollout_ddp_rank}) to consume. Please increase the train datasets or reduce the number of GPUs."
assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
# Calculate the number of batches and total training steps
num_batches = len(self.train_dataloader) if self.train_dataloader else 0
total_training_steps = num_batches * self.global_config.trainer.total_epochs
# Use user-specified total_training_steps if provided
if self.global_config.trainer.total_training_steps is not None:
total_training_steps = self.global_config.trainer.total_training_steps
self.total_training_steps = total_training_steps
# Update total training steps in optimizer configs for actor and critic
self.global_config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.global_config.critic.optim.total_training_steps = total_training_steps
# Indicates that this rank's validation samples were padded with duplicates to align sizes
self.is_val_trailing_rank = self.val_dataset.is_trailing_rank
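# Illustrative arithmetic (config values assumed): with
# data.train_batch_size=1024, group_world_size=32 and group_parallel_size=4,
# rollout_ddp_world_size is 32 // 4 = 8, so each rank's StatefulDataLoader
# is built with batch_size = 1024 // 8 = 128.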
def _reinit_dataloader_sampler(self):
"""
Re-initializes the sampler and dataloader to clear any internal state (like being exhausted).
This is useful when resuming from a checkpoint that was saved at the end of an epoch.
"""
# Re-create the sampler
if self.global_config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.global_config.trainer.seed)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
# Re-create the dataloader with the new sampler
from siirl.data_coordinator.dataloader.partitioned_dataset import collate_fn as default_collate_fn
train_batch_size = self.global_config.data.train_batch_size // self.rollout_ddp_world_size
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=train_batch_size,
num_workers=self.num_loader_workers,
drop_last=True,
collate_fn=default_collate_fn,
sampler=sampler,
)
logger.info(f"Node {self.node_id}: Re-initialized dataloader and sampler.")
def get_train_dataloader(self):
"""
Returns the training data loader.
Returns:
DataLoader: The data loader used for training.
"""
return self.train_dataloader
def get_val_dataloader(self):
"""
Returns the validation dataloader.
Returns:
DataLoader: The dataloader used for validation data.
"""
return self.val_dataloader
def run(self, epoch: Optional[int] = None, is_validation_step: bool = False, **kwargs: Any) -> Any:
"""
Executes the data loading process for a given step or validation.
Args:
epoch (Optional[int]): The current epoch number. Required for training steps to handle
sampler state (e.g., DistributedSampler.set_epoch()).
is_validation_step (bool): Flag indicating if validation data is requested.
**kwargs: Additional arguments (not used directly in this basic version but
part of the Node.execute signature).
Returns:
Any: A batch of data. The structure depends on the collate_fn.
Raises:
ValueError: If epoch is not provided for a training step.
StopIteration: If the dataloader is exhausted and cannot provide more data
(though this might be handled by the DAG scheduler).
"""
self.update_status(NodeStatus.RUNNING)
logger.debug(f"Node {self.node_id} execute: epoch={epoch}, is_validation_step={is_validation_step}")
try:
if is_validation_step:
if not self.val_dataloader: # Handles empty validation dataset
logger.warning(f"Rank {self.group_rank}: Validation dataloader is not available or empty.")
self.update_status(NodeStatus.COMPLETED) # Or FAILED if this is an error condition
return None # Or an empty batch marker
# Validation dataloader loads the entire validation set as one batch.
# We get a fresh iterator each time for validation.
if self._current_val_iter is None:
self._current_val_iter = iter(self.val_dataloader)
try:
batch = next(self._current_val_iter)
logger.debug(f"Node {self.node_id}: Yielding validation batch.")
# Reset for next validation call, as it's one batch
self._current_val_iter = None
except StopIteration:
logger.warning(
f"Node {self.node_id}: Validation dataloader exhausted unexpectedly (should be one batch). Resetting."
)
# This case should ideally not happen if batch_size = len(dataset) and it's not empty
self._current_val_iter = iter(self.val_dataloader) # Get a fresh iterator
try:
batch = next(self._current_val_iter)
except StopIteration:
logger.error(f"Node {self.node_id}: Validation dataloader is empty even after reset.")
self.update_status(NodeStatus.FAILED, "Validation dataloader empty")
return None
else: # Training step
if epoch is None:
error_msg = "Epoch must be provided for training steps."
logger.error(f"Node {self.node_id}: {error_msg}")
self.update_status(NodeStatus.FAILED, error_msg)
raise ValueError(error_msg)
if not self.train_dataloader: # Handles empty training dataset
logger.warning(f"Rank {self.group_rank}: Training dataloader is not available or empty.")
self.update_status(NodeStatus.COMPLETED) # Or FAILED
return None # Or an empty batch marker
# Flag to track if we just created the iterator
iterator_just_created = False
if self._current_epoch != epoch or self._current_train_iter is None:
logger.info(f"Node {self.node_id}: New epoch ({epoch}) or first step. Initializing train iterator.")
self._current_epoch = epoch
# Set epoch for DistributedSampler if applicable
if hasattr(self.train_dataloader.sampler, "set_epoch") and isinstance(
self.train_dataloader.sampler, DistributedSampler
):
logger.debug(f"Node {self.node_id}: Setting epoch {epoch} for DistributedSampler.")
self.train_dataloader.sampler.set_epoch(epoch)
self._current_train_iter = iter(self.train_dataloader)
iterator_just_created = True
try:
batch = next(self._current_train_iter)
logger.debug(f"Node {self.node_id}: Yielding training batch for epoch {epoch}.")
except StopIteration:
# FIX: Handle resume from end-of-epoch state
if iterator_just_created:
logger.warning(
f"Node {self.node_id}: Iterator exhausted immediately after creation. "
f"This indicates resumption from a completed epoch state. "
f"Resetting dataloader to start fresh for epoch {epoch}."
)
# Re-create the dataloader to clear the internal "exhausted" state
# We keep the dataset to avoid reloading heavy data
self._reinit_dataloader_sampler()
self._current_train_iter = iter(self.train_dataloader)
batch = next(self._current_train_iter)
else:
# Real end of epoch
error_msg = (
f"Training dataloader exhausted for epoch {epoch}. This might be expected at epoch end."
)
logger.info(f"Node {self.node_id}: {error_msg}")
# We might not want to mark FAILED here, as it's a natural end of an iterator.
# The caller (DAG executor) should decide if more data was expected.
# For now, let's re-raise StopIteration to signal the caller.
self.update_status(NodeStatus.COMPLETED) # Or a custom status like 'EPOCH_END'
raise # Re-raise StopIteration
self.update_status(NodeStatus.COMPLETED)
return batch
except Exception as e:
error_msg = f"Error during data loading in node {self.node_id}: {e}"
logger.exception(error_msg) # Log with stack trace
self.update_status(NodeStatus.FAILED, str(e))
raise # Re-raise the exception so the DAG executor can handle it
def state_dict(self) -> Dict[str, Any]:
"""
Captures the state of the DataLoaderNode, primarily the training dataloader's state.
Returns:
Dict[str, Any]: A dictionary containing the node's state.
"""
return {
"train_dataloader_state": self.train_dataloader.state_dict(),
}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""
Restores the state of the DataLoaderNode from a state dictionary.
Args:
state_dict (Dict[str, Any]): The state dictionary to load.
"""
if "train_dataloader_state" in state_dict:
self.train_dataloader.load_state_dict(state_dict["train_dataloader_state"])
# After loading state, the current iterator is invalid because it's tied to the old
# sampler state. Setting it to None forces the run() method to create a new,
# valid iterator that is synchronized with the restored state.
self._current_train_iter = None
logger.info(
f"Node {self.node_id} (Rank {self.group_rank}): Successfully loaded train_dataloader state. Iterator will be reset on next call."
)
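# Hedged sketch of the checkpoint round trip (caller names are illustrative):
#   state = loader_node.state_dict()               # captures StatefulDataLoader state
#   ... persist `state` alongside the trainer checkpoint ...
#   loader_node.load_state_dict(state)             # restores it; iterator reset lazily
#   batch = loader_node.run(epoch=resumed_epoch)   # resumes from the saved position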
================================================
FILE: siirl/data_coordinator/dataloader/embodied_preprocess.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Tuple
import pandas as pd
from libero.libero import benchmark
from loguru import logger
def prepare_libero_train_valid_datasets(
task_suite_name: str,
num_trials_per_task: int,
dataset_dir: str,
) -> Tuple[Path, Path]:
"""
Generates identical training and validation dataset manifests for a LIBERO task suite.
This manifest contains the necessary metadata for the VLA agent to initialize environments.
The function queries the actual number of initial states for each task and generates
trial_ids up to min(actual_initial_states, num_trials_per_task) to ensure all
generated (task_id, trial_id) pairs are valid.
Args:
task_suite_name (str): The name of the task suite.
num_trials_per_task (int): The maximum number of trials to include for each task.
Actual trials may be less if a task has fewer initial states.
dataset_dir (str): The directory where the manifest files will be saved.
Examples:
- Task with 50 initial states, num_trials_per_task=100 → generates trial_ids 0-49
- Task with 150 initial states, num_trials_per_task=100 → generates trial_ids 0-99
- Task with 150 initial states, num_trials_per_task=200 → generates trial_ids 0-149
"""
# 1. --- Validate parameters and prepare output directory ---
if num_trials_per_task < 1:
raise ValueError("Number of trials per task must be at least 1")
output_path = Path(dataset_dir)
output_path.mkdir(parents=True, exist_ok=True)
train_file = output_path / "train.parquet"
valid_file = output_path / "validate.parquet"
# 2. --- Get task info from LIBERO benchmark ---
try:
task_suite = benchmark.get_benchmark_dict()[task_suite_name]()
num_tasks_in_suite = task_suite.n_tasks
except KeyError as err:
raise ValueError(
f"Task suite '{task_suite_name}' not found in benchmark."
) from err
logger.info(f"Found {num_tasks_in_suite} tasks in '{task_suite_name}'.")
logger.info(
f"Requested maximum of {num_trials_per_task} trials per task."
)
# 3. --- Query actual initial state counts for each task ---
logger.info("Querying initial state counts for each task...")
task_initial_state_counts = []
for task_id in range(num_tasks_in_suite):
initial_states = task_suite.get_task_init_states(task_id)
num_initial_states = len(initial_states)
task_initial_state_counts.append(num_initial_states)
logger.debug(
f"Task {task_id}: {num_initial_states} initial states available"
)
# 4. --- Generate records with per-task trial_id limits ---
logger.info("Generating dataset records with per-task limits...")
all_records = []
total_capped_tasks = 0
for task_id in range(num_tasks_in_suite):
actual_num_states = task_initial_state_counts[task_id]
# Cap at actual available initial states
max_trials_for_task = min(actual_num_states, num_trials_per_task)
# Log if this task is being capped
if max_trials_for_task < num_trials_per_task:
logger.info(
f"Task {task_id}: Capping at {max_trials_for_task} trials "
f"(has {actual_num_states} initial states, requested {num_trials_per_task})"
)
total_capped_tasks += 1
# Generate records for this task
for trial_id in range(max_trials_for_task):
all_records.append({
"task_suite_name": task_suite_name,
"task_id": task_id,
"trial_id": trial_id,
"prompt_id": f"{task_suite_name}_{task_id}_{trial_id}",
})
# Log summary statistics
expected_records = num_tasks_in_suite * num_trials_per_task
actual_records = len(all_records)
logger.info(
f"Generated {actual_records} records "
f"(expected {expected_records} if all tasks had {num_trials_per_task} states)"
)
if total_capped_tasks > 0:
logger.info(
f"{total_capped_tasks} out of {num_tasks_in_suite} tasks were capped "
f"due to insufficient initial states"
)
# 5. --- Save to both train and validation Parquet files ---
try:
if all_records:
df = pd.DataFrame(all_records)
df.to_parquet(train_file, index=False)
df.to_parquet(valid_file, index=False)
logger.success(
f"✅ VLA task manifests successfully saved to '{output_path}'."
)
else:
logger.warning("No records were generated for the VLA task manifest.")
except ImportError:
logger.error(
"`pandas` and `pyarrow` are required. Please run: `pip install pandas pyarrow`"
)
raise
except Exception as e:
logger.error(f"Error saving Parquet file for VLA manifest: {e}")
raise
return train_file, valid_file
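# Hedged usage sketch (task suite and path are illustrative assumptions):
#   train_file, valid_file = prepare_libero_train_valid_datasets(
#       task_suite_name="libero_spatial",
#       num_trials_per_task=50,
#       dataset_dir="/tmp/libero_manifests",
#   )
# Both manifests contain identical (task_suite_name, task_id, trial_id,
# prompt_id) rows, capped per task at the number of available initial states.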
================================================
FILE: siirl/data_coordinator/dataloader/partitioned_dataset.py
================================================
# Copyright 2025, Shanghai Innovation Institute. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from tqdm import tqdm
from typing import Dict, Optional, Sequence
import datasets
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from loguru import logger
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
import siirl.utils.model_utils.torch_functional as F
from siirl.utils.model_utils.model import compute_position_id_with_mask
from siirl.params import SiiRLArguments
def collate_fn(data_list: list[dict]) -> dict:
"""
Collate a batch of sample dicts into batched tensors and arrays.
Args:
data_list: List of dicts mapping feature names to torch.Tensor or other values.
Returns:
Dict where tensor entries are stacked into a torch.Tensor of shape
(batch_size, *dims) and non-tensor entries are converted to
np.ndarray of dtype object with shape (batch_size,).
"""
tensors = defaultdict(list)
non_tensors = defaultdict(list)
# Fields that should be converted to tensors for embodied tasks
embodied_tensor_fields = {'task_id', 'trial_id'}
for data in data_list:
for key, val in data.items():
if isinstance(val, torch.Tensor):
tensors[key].append(val)
elif key in embodied_tensor_fields and isinstance(val, (int, np.integer)):
tensors[key].append(torch.tensor(val, dtype=torch.int64))
else:
non_tensors[key].append(val)
for key, val in tensors.items():
tensors[key] = torch.stack(val, dim=0)
for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)
return {**tensors, **non_tensors}
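# A minimal sketch of what collate_fn produces (hypothetical field names):
#
#   >>> batch = collate_fn([
#   ...     {"input_ids": torch.zeros(4, dtype=torch.long), "task_id": 1, "name": "a"},
#   ...     {"input_ids": torch.ones(4, dtype=torch.long), "task_id": 2, "name": "b"},
#   ... ])
#   >>> batch["input_ids"].shape  # tensor fields are stacked along a new batch dim
#   torch.Size([2, 4])
#   >>> batch["task_id"].dtype    # embodied integer fields become int64 tensors
#   torch.int64
#   >>> batch["name"]             # everything else becomes an object ndarray
#   array(['a', 'b'], dtype=object)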
class PartitionedRLHFDataset(Dataset):
"""
An efficient Dataset class for distributed training. It only loads and processes
the data partition (slice) of the RLHF dataset corresponding to the current DDP rank.
Args:
config (SiiRLArguments): Configuration object containing data and preprocessing arguments.
tokenizer (PreTrainedTokenizer): Tokenizer for processing text prompts.
processor (Optional[ProcessorMixin]): Optional processor for multi-modal data (e.g., images, videos).
ddp_rank (int): The rank of the current process in DDP.
ddp_world_size (int): Total number of DDP processes (world size).
is_eval (bool): Whether the dataset is for evaluation (True) or training (False).
drop_last (Optional[bool]): Whether to drop the remainder rows
when the total row count is not divisible by the world size.
Defaults to True for training, False for evaluation.
Notes:
- This class is optimized for distributed training scenarios, ensuring each DDP process only
loads and processes its own data partition.
- Supports multi-modal data (text, images, videos) if a processor is provided.
- Handles prompt filtering, truncation, and tokenization according to configuration.
- Uses a thread pool for preprocessing and multiple processes for prompt filtering.
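Example:
    A minimal usage sketch (assumes a populated SiiRLArguments `config` and a
    loaded `tokenizer`; the names here are illustrative):
    >>> dataset = PartitionedRLHFDataset(config, tokenizer, ddp_rank=0, ddp_world_size=4)
    >>> len(dataset)  # counts only the rows assigned to rank 0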
"""
def __init__(
self,
config: SiiRLArguments,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin] = None,
ddp_rank: int = 0,
ddp_world_size: int = 1,
is_eval: bool = False,
drop_last: Optional[bool] = None,
):
super().__init__()
self.tokenizer = tokenizer
self.processor = processor
self.data_args = config.data
self._rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
self.ddp_rank = ddp_rank
self.ddp_world_size = ddp_world_size
self.is_eval = is_eval
self.drop_last = drop_last if drop_last is not None else (not is_eval)
# The dataset type determines the processing pipeline.
# 'llm': Standard processing for prompt-based datasets.
# 'embodied': Minimal processing for embodied task manifests, passing metadata through.
self.dataset_type = self.data_args.dataset_type
if self._rank == 0:
logger.info(f"Initializing dataset with dataset_type='{self.dataset_type}'.")
self.prompt_key = self.data_args.prompt_key
self.image_key = self.data_args.image_key
self.video_key = self.data_args.video_key
self.max_prompt_length = self.data_args.max_prompt_length
self.truncation = self.data_args.truncation
self.return_raw_chat = self.data_args.return_raw_chat
self.return_full_prompt = self.data_args.return_full_prompt
self.filter_overlong_prompts = self.data_args.filter_overlong_prompts
self.num_workers = self.data_args.preprocessing_num_workers if self.data_args.preprocessing_num_workers else max(1, os.cpu_count() // 8)
self.force_on_the_fly = config.data.force_on_the_fly
self.image_max_pixels = self.data_args.processor.image_max_pixels
self.image_min_pixels = self.data_args.processor.image_min_pixels
self.video_max_pixels = self.data_args.processor.video_max_pixels
self.video_min_pixels = self.data_args.processor.video_min_pixels
self.video_fps = self.data_args.processor.video_fps
self.video_maxlen = self.data_args.processor.video_maxlen
self.multi_turn = config.actor_rollout_ref.rollout.multi_turn.enable
self.is_trailing_rank = False  # Marks ranks that received one fewer row under the remainder-based contiguous partitioning.
if self._rank == 0:
logger.debug(f"Initializing PartitionedRLHFDataset with DDP rank {self.ddp_rank}, world size {self.ddp_world_size}, is_eval={self.is_eval}, drop_last={self.drop_last}")
if self.processor is not None:
logger.info(f"Set image_max_pixels={self.image_max_pixels}, image_min_pixels={self.image_min_pixels}, video_max_pixels={self.video_max_pixels}, video_min_pixels={self.video_min_pixels}, you can change these values via data.processor.image_max_pixels, etc.")
# 1. Load the raw data partition for the current DDP rank
dataset_files = self.data_args.val_files if is_eval else self.data_args.train_files
raw_dataframe = self._load_partitioned_raw_data(dataset_files)
if raw_dataframe is None or len(raw_dataframe) == 0:
logger.warning(f"DDP rank {self.ddp_rank} received no data.")
self.processed_dataframe = None
return
# 2. Filter out prompts that are too long from the loaded partition
raw_dataframe = self._filter_overlong_prompts(raw_dataframe)
# If this rank received fewer samples due to uneven partitioning, pad by duplicating the last row.
if self.is_trailing_rank and raw_dataframe is not None and len(raw_dataframe) > 0:
try:
# Duplicate the last row to pad the partition
last_row = raw_dataframe[-1].copy()
if "extra_info" in last_row and isinstance(last_row["extra_info"], dict):
last_row["extra_info"]["padded_duplicate"] = True
else:
last_row["extra_info"] = {"padded_duplicate": True}
last_row_ds = datasets.Dataset.from_list([last_row])
raw_dataframe = datasets.concatenate_datasets([raw_dataframe, last_row_ds])
logger.debug(f"DDP rank {self.ddp_rank} is a trailing rank, duplicating last row to pad partition. New length: {len(raw_dataframe)}")
except Exception as e:
# Padding is best-effort: downstream code identifies padded rows via the 'padded_duplicate' flag, so a failure here is non-fatal.
logger.debug(f"Failed to pad trailing rank {self.ddp_rank}: {e}")
# Based on the dataset type, we either perform full preprocessing or just
# pass the raw embodied metadata through.
# Initialize load_on_the_fly before branching to ensure it's always set
self.load_on_the_fly = self.force_on_the_fly
if self.dataset_type == "embodied":
# For embodied, we do no preprocessing. The dataset is just the raw manifest.
# The tokenizer and processor are not used at this stage.
self.processed_dataframe = raw_dataframe
if self._rank == 0:
logger.info("Embodied dataset type detected. Skipping tokenization and prompt processing.")
else:
# For 'llm' (default), proceed with the original preprocessing logic.
# 3. Preprocess the entire partition using multiple processes
# By only removing the specific prompt_key, we ensure that other columns,
# including complex types like dicts and strings from the original dataset,
# are preserved. The .map() function will then add the new columns
# returned by _preprocess_function. This is safer than removing all
# columns and rebuilding the dataset from scratch.
if self.load_on_the_fly:
self.processed_dataframe = raw_dataframe
else:
if self._rank == 0:
logger.warning("Currently preloading and preprocessing the entire dataset. If you encounter Out-Of-Memory issues, "
f"please set data.force_on_the_fly=True to enable on-the-fly loading mode. "
f"Now the dataset type is {self.dataset_type}.")
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
self.processed_dataframe = list(tqdm(
executor.map(self._preprocess_function, raw_dataframe),
total=len(raw_dataframe),
desc="Processing"
))
def _load_partitioned_raw_data(self, dataset_files: Sequence[str]) -> Optional[datasets.Dataset]:
"""
Loads a partition of Parquet data for the current DDP rank.
"""
if not dataset_files:
raise RuntimeError("No dataset files configured, aborting...")
try:
pq_files = [pq.ParquetFile(f) for f in dataset_files]
# Gather (file_idx, row_group_idx, num_rows, start_row_idx_global)
row_group_infos = []
total_rows = 0
for file_idx, pq_file in enumerate(pq_files):
for rg_idx in range(pq_file.num_row_groups):
num_rows = pq_file.metadata.row_group(rg_idx).num_rows
row_group_infos.append({"file_idx": file_idx, "row_group_idx": rg_idx, "num_rows": num_rows, "start_row_idx_global": total_rows})
total_rows += num_rows
if self._rank == 0:
logger.debug(f"DDP rank={self.ddp_rank}, row group infos: {row_group_infos}")
if total_rows < self.ddp_world_size:
raise RuntimeError(f"Total rows ({total_rows}) is less than DDP world size ({self.ddp_world_size}), "
f"cannot partition data across ranks. Please ensure enough data is available. "
f"Now the dataset type is {self.dataset_type}.")
# Compute partition indices
if self.drop_last:
rows_per_rank = total_rows // self.ddp_world_size
total_used_rows = rows_per_rank * self.ddp_world_size
start = self.ddp_rank * rows_per_rank
end = start + rows_per_rank
if self._rank == 0:
logger.warning(
f"DDP Rank {self.ddp_rank} using drop_last=True, partitioning rows into {self.ddp_world_size} ranks "
f"with {rows_per_rank} rows each. Total used rows: {total_used_rows}, start={start}, end={end}. "
f"Total rows: {total_rows}, total dropped rows: {total_rows - total_used_rows}."
)
else:
# Distribute the remainder to the first (total_rows % ddp_world_size) ranks
rows_per_rank = total_rows // self.ddp_world_size
remainder = total_rows % self.ddp_world_size
if self.ddp_rank < remainder:
start = self.ddp_rank * (rows_per_rank + 1)
end = start + rows_per_rank + 1
else:
start = remainder * (rows_per_rank + 1) + (self.ddp_rank - remainder) * rows_per_rank
end = start + rows_per_rank
if remainder > 0:
    self.is_trailing_rank = True  # This rank received one fewer row than the first `remainder` ranks; no padding is needed when the split is exact.
if start >= end:
raise RuntimeError(f"Rank {self.ddp_rank} assigned empty partition: start={start}, end={end}, total_rows={total_rows}")
# Find which row groups overlap with [start, end)
selected_chunks = []
for info in row_group_infos:
rg_start = info["start_row_idx_global"]
rg_end = rg_start + info["num_rows"]
# If overlap
if rg_end > start and rg_start < end:
# Compute local slice within this row group
local_start = max(0, start - rg_start)
local_end = min(info["num_rows"], end - rg_start)
selected_chunks.append({"file_idx": info["file_idx"], "row_group_idx": info["row_group_idx"], "local_start": local_start, "local_end": local_end})
# Read and slice the necessary row groups
tables = []
for chunk in selected_chunks:
pq_file = pq_files[chunk["file_idx"]]
table = pq_file.read_row_group(chunk["row_group_idx"])
if chunk["local_start"] > 0 or chunk["local_end"] < table.num_rows:
table = table.slice(chunk["local_start"], chunk["local_end"] - chunk["local_start"])
tables.append(table)
if not tables:
raise RuntimeError(f"DDP Rank {self.ddp_rank} assigned rows [{start}, {end}) but failed to read any data.")
final_table = pa.concat_tables(tables)
logger.debug(f"DDP rank={self.ddp_rank} loaded {len(final_table)} rows from {len(tables)} row groups. start={start}, end={end}, total_rows={total_rows}.")
return datasets.Dataset(final_table)
except Exception as e:
logger.error(f"Failed during partitioned data loading for DDP rank {self.ddp_rank}: {dataset_files}. Error: {e}")
raise
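# A standalone sketch (hypothetical path) of the metadata-only planning used
# above: pyarrow exposes per-row-group row counts without reading any data,
# which keeps partition planning cheap even for very large Parquet files.
#
#   >>> pf = pq.ParquetFile("train.parquet")
#   >>> [pf.metadata.row_group(i).num_rows for i in range(pf.num_row_groups)]
#   [1024, 1024, 512]  # hypothetical row-group sizes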
def _filter_overlong_prompts(self, raw_dataframe: datasets.Dataset) -> datasets.Dataset:
if self.filter_overlong_prompts:
original_len = len(raw_dataframe)
raw_dataframe = raw_dataframe.filter(
lambda doc: len(self.tokenizer.apply_chat_template(doc[self.prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Rank {self.ddp_rank} filtering prompts longer than {self.max_prompt_length} tokens",
)
filtered_len = len(raw_dataframe)
if self._rank == 0:
logger.info(f"Filtered prompts from {original_len} to {filtered_len} on each rank.")
return raw_dataframe
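# A minimal sketch of the per-row check applied above (hypothetical message;
# 'prompt' stands in for the configured prompt_key):
#
#   >>> doc = {"prompt": [{"role": "user", "content": "What is 2 + 2?"}]}
#   >>> ids = tokenizer.apply_chat_template(doc["prompt"], add_generation_prompt=True)
#   >>> keep = len(ids) <= max_prompt_length  # the row survives the filter iff True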
def __len__(self) -> int:
return len(self.processed_dataframe) if self.processed_dataframe is not None else 0
def _build_messages(self, example: dict) -> list:
"""Helper function to structure messages, adopted from RLHFDataset."""
# The pop operation is safe here because map creates a copy for each process
messages: list = example.pop(self.prompt_key)
if self.image_key in example or self.video_key in example:
for message in messages:
content = message["content"]
content_list = []
# Simple split logic to handle interleaved text and images/videos
for segment in re.split("(|