Copy disabled (too large)
Download .txt
Showing preview only (13,545K chars total). Download the full file to get everything.
Repository: Kuaishou-OneRec/OpenOneRec
Branch: main
Commit: 1f4a99ea4708
Files: 1960
Total size: 12.5 MB
Directory structure:
gitextract_yy4wnsy3/
├── .gitignore
├── README.md
├── benchmarks/
│ ├── LICENSE
│ ├── README.md
│ ├── api/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── config/
│ │ │ └── llm_config.json
│ │ ├── deepseek.py
│ │ ├── example.py
│ │ └── gemini.py
│ ├── benchmark/
│ │ ├── __init__.py
│ │ ├── base_generator.py
│ │ ├── benchmark.py
│ │ ├── checkpoint_utils.py
│ │ ├── console.py
│ │ ├── generation_runner.py
│ │ ├── gpu_utils.py
│ │ └── tasks/
│ │ ├── __init__.py
│ │ ├── tasks.py
│ │ └── v1_0/
│ │ ├── __init__.py
│ │ ├── base_evaluator.py
│ │ ├── base_loader.py
│ │ ├── item_understand/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── label_pred/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── mfu_evaluator.py
│ │ ├── qwen3.jinja2
│ │ ├── qwen3_soft_switch.jinja2
│ │ ├── rec_reason/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── recommendation/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ ├── utils.py
│ │ │ └── utils_by_pid.py
│ │ └── registry.py
│ ├── eval_script.sh
│ ├── pyproject.toml
│ ├── requirements.txt
│ └── scripts/
│ ├── __init__.py
│ ├── eval_dev_results.py
│ ├── init_ray.sh
│ ├── init_ray_cluster.sh
│ └── ray-vllm/
│ ├── evaluate.py
│ └── utils/
│ ├── __init__.py
│ ├── arguments.py
│ └── generator.py
├── data/
│ ├── README.md
│ ├── general_text/
│ │ ├── pretrain.csv
│ │ └── sft.csv
│ ├── onerec_data/
│ │ ├── README.md
│ │ ├── pretrain/
│ │ │ ├── item_understand.py
│ │ │ ├── user_profile.py
│ │ │ └── video_rec.py
│ │ ├── run.sh
│ │ └── sft/
│ │ ├── ad_rec.py
│ │ ├── interactive_rec.py
│ │ ├── item_understand.py
│ │ ├── label_cond_rec.py
│ │ ├── label_pred.py
│ │ ├── product_rec.py
│ │ ├── rec_reason.py
│ │ └── video_rec.py
│ ├── prepare_distillation.sh
│ ├── prepare_pretrain.sh
│ ├── prepare_rl.sh
│ ├── prepare_sft.sh
│ └── scripts/
│ ├── parquet_unicode_fix.py
│ ├── sample_data.py
│ ├── split_data.py
│ └── train_test_split.py
├── pretrain/
│ ├── .gitignore
│ ├── README.md
│ ├── examples/
│ │ ├── dataset_config/
│ │ │ ├── pretrain.json
│ │ │ └── sft.json
│ │ ├── posttrain_sft.sh
│ │ ├── pretrain_stg1.sh
│ │ └── pretrain_stg2.sh
│ ├── onerec_llm/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── dataloaders.py
│ │ │ ├── local_shuffle_buffer.py
│ │ │ └── qwen3_dataset.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ └── ce.py
│ │ ├── models/
│ │ │ └── qwen3/
│ │ │ ├── __init__.py
│ │ │ ├── configuration_qwen3.py
│ │ │ ├── modeling_qwen3.py
│ │ │ └── modular_qwen3.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── checkpoint.py
│ │ │ ├── common.py
│ │ │ ├── distributed.py
│ │ │ ├── gradients.py
│ │ │ └── lr_schedulers.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── data_utils.py
│ │ ├── distributed.py
│ │ ├── ds_utils.py
│ │ ├── mfu_stats.py
│ │ ├── time_tracker.py
│ │ └── worker_utils.py
│ ├── recipes/
│ │ └── train_qwen3.py
│ ├── scripts/
│ │ ├── convert_checkpoint_to_hf.sh
│ │ ├── expand_qwen3_vocab.sh
│ │ ├── killall.sh
│ │ ├── numa_runner.sh
│ │ ├── test_cases_example.json
│ │ └── test_hf_model.sh
│ ├── set_env.sh
│ ├── tests/
│ │ └── test_qwen3_dataset_file_distribution.py
│ └── tools/
│ ├── model_converter/
│ │ ├── convert_checkpoint_to_hf.py
│ │ └── expand_qwen3_vocab.py
│ └── model_test/
│ └── test_hf_model.py
├── tokenizer/
│ ├── README.md
│ ├── infer_res_kmeans.py
│ ├── res_kmeans.py
│ └── train_res_kmeans.py
├── verl_distillation/
│ ├── LICENSE
│ ├── README.md
│ ├── README_ORIGINAL.md
│ ├── deploy_env.sh
│ ├── docker/
│ │ ├── Apptainerfile.rocm
│ │ ├── Dockerfile.extention.awsefa
│ │ ├── Dockerfile.ngc.vllm
│ │ ├── Dockerfile.ngc.vllm0.8
│ │ ├── Dockerfile.ngc.vllm0.8.sagemaker
│ │ ├── Dockerfile.rocm
│ │ ├── Dockerfile.rocm7
│ │ ├── Dockerfile.rocm_verl-0.3.0.post1
│ │ ├── Dockerfile.rocm_verl-0.4.1
│ │ ├── Dockerfile.sglang
│ │ ├── Dockerfile.vemlp.vllm.te
│ │ ├── Dockerfile.vllm.sglang.megatron.deepseek
│ │ ├── README.md
│ │ ├── ascend/
│ │ │ ├── Dockerfile.ascend_8.2.rc1_a2
│ │ │ └── Dockerfile.ascend_8.2.rc1_a3
│ │ ├── verl0.4-cu124-torch2.6-fa2.7.4/
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.12
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.12.deepep
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.13.preview
│ │ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ │ ├── Dockerfile.app.vllm.mcore0.12.deepep
│ │ │ ├── Dockerfile.app.vllm.mcore0.13.preview
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ ├── verl0.5-cu126-torch2.7-fa2.7.4/
│ │ │ ├── Dockerfile.app.sglang0.4.10.post2.mcore0.13
│ │ │ ├── Dockerfile.app.sglang0.4.9.post6.mcore0.13
│ │ │ ├── Dockerfile.app.vllm.mcore0.13
│ │ │ ├── Dockerfile.app.vllm.mcore0.15
│ │ │ ├── Dockerfile.base.torch2.7.1
│ │ │ └── README.md
│ │ ├── verl0.5-cu126-torch2.7.1-fa2.8.0/
│ │ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ │ ├── Dockerfile.app.sglang.mcore0.13.preview
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ ├── verl0.5-preview-cu128-torch2.7.1-fa2.8.0/
│ │ │ ├── Dockerfile.app.sglang.megatron
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ └── verl0.6-cu128-torch2.8.0-fa2.7.4/
│ │ ├── Dockerfile.app.sglang
│ │ ├── Dockerfile.base
│ │ └── Dockerfile.vllm011.mcore_gpt-oss
│ ├── docs/
│ │ ├── Makefile
│ │ ├── README.md
│ │ ├── README_vllm0.7.md
│ │ ├── README_vllm0.8.md
│ │ ├── _static/
│ │ │ ├── custom.css
│ │ │ └── js/
│ │ │ ├── resizable-sidebar.js
│ │ │ └── runllm-widget.js
│ │ ├── advance/
│ │ │ ├── agent_loop.rst
│ │ │ ├── attention_implementation.rst
│ │ │ ├── checkpoint.rst
│ │ │ ├── dpo_extension.rst
│ │ │ ├── fsdp_extension.rst
│ │ │ ├── fully_async.md
│ │ │ ├── megatron_extension.rst
│ │ │ ├── one_step_off.md
│ │ │ ├── placement.rst
│ │ │ ├── ppo_lora.rst
│ │ │ ├── reward_loop.rst
│ │ │ ├── rollout_is.md
│ │ │ ├── rollout_skip.rst
│ │ │ ├── rollout_trace.rst
│ │ │ └── rope.rst
│ │ ├── algo/
│ │ │ ├── baseline.md
│ │ │ ├── collabllm.md
│ │ │ ├── dapo.md
│ │ │ ├── entropy.md
│ │ │ ├── gpg.md
│ │ │ ├── grpo.md
│ │ │ ├── opo.md
│ │ │ ├── ppo.md
│ │ │ ├── spin.md
│ │ │ └── sppo.md
│ │ ├── amd_tutorial/
│ │ │ ├── amd_build_dockerfile_page.rst
│ │ │ └── amd_vllm_page.rst
│ │ ├── api/
│ │ │ ├── data.rst
│ │ │ ├── single_controller.rst
│ │ │ ├── trainer.rst
│ │ │ └── utils.rst
│ │ ├── ascend_tutorial/
│ │ │ ├── ascend_profiling_en.rst
│ │ │ ├── ascend_profiling_zh.rst
│ │ │ ├── ascend_quick_start.rst
│ │ │ ├── ascend_sglang_quick_start.rst
│ │ │ └── dockerfile_build_guidance.rst
│ │ ├── conf.py
│ │ ├── data/
│ │ │ └── transfer_queue.md
│ │ ├── examples/
│ │ │ ├── config.rst
│ │ │ ├── gsm8k_example.rst
│ │ │ ├── multi_modal_example.rst
│ │ │ ├── ppo_code_architecture.rst
│ │ │ ├── sandbox_fusion_example.rst
│ │ │ └── skypilot_examples.rst
│ │ ├── faq/
│ │ │ └── faq.rst
│ │ ├── hybrid_flow.rst
│ │ ├── index.rst
│ │ ├── perf/
│ │ │ ├── best_practices.rst
│ │ │ ├── device_tuning.rst
│ │ │ ├── dpsk.md
│ │ │ ├── nsight_profiling.md
│ │ │ ├── perf_tuning.rst
│ │ │ └── verl_profiler_system.md
│ │ ├── preparation/
│ │ │ ├── prepare_data.rst
│ │ │ └── reward_function.rst
│ │ ├── requirements-docs.txt
│ │ ├── sglang_multiturn/
│ │ │ ├── interaction_system.rst
│ │ │ ├── multiturn.rst
│ │ │ ├── sandbox_fusion.rst
│ │ │ └── search_tool_example.rst
│ │ ├── single_controller.rst
│ │ ├── start/
│ │ │ ├── agentic_rl.rst
│ │ │ ├── install.rst
│ │ │ ├── more_resources.rst
│ │ │ ├── multinode.rst
│ │ │ ├── quickstart.rst
│ │ │ └── ray_debug_tutorial.rst
│ │ └── workers/
│ │ ├── fsdp_workers.rst
│ │ ├── megatron_workers.rst
│ │ ├── model_engine.rst
│ │ ├── ray_trainer.rst
│ │ └── sglang_worker.rst
│ ├── examples/
│ │ ├── data_preprocess/
│ │ │ ├── aime2024_multiturn_w_tool.py
│ │ │ ├── dapo_multiturn_w_tool.py
│ │ │ ├── full_hh_rlhf.py
│ │ │ ├── geo3k.py
│ │ │ ├── geo3k_multiturn_w_tool.py
│ │ │ ├── gsm8k.py
│ │ │ ├── gsm8k_multiturn_sft.py
│ │ │ ├── gsm8k_multiturn_w_interaction.py
│ │ │ ├── gsm8k_multiturn_w_tool.py
│ │ │ ├── gsm8k_tool_agent_loop.py
│ │ │ ├── hellaswag.py
│ │ │ ├── math_dataset.py
│ │ │ ├── multiturn.py
│ │ │ └── preprocess_search_r1_dataset.py
│ │ ├── generation/
│ │ │ ├── run_deepseek7b_mutli_node.sh
│ │ │ └── run_deepseek_v2_lite_math.sh
│ │ ├── gmpo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_qwen2_5-7b_math.sh
│ │ │ ├── test_dapo_7b_math.sh
│ │ │ └── test_dapo_qwen3_30b_math.sh
│ │ ├── gpg_trainer/
│ │ │ ├── gpg.md
│ │ │ ├── run_qwen2-7b_math.sh
│ │ │ └── run_qwen2-7b_math_megatron.sh
│ │ ├── grpo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_deepseek671b_math_megatron_80gb.sh
│ │ │ ├── run_deepseek671b_math_megatron_96gb.sh
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ ├── run_deepseek7b_llm_math.sh
│ │ │ ├── run_deepseek7b_llm_math_megatron.sh
│ │ │ ├── run_deepseek7b_llm_seq_balance.sh
│ │ │ ├── run_glm41v_9b.sh
│ │ │ ├── run_gptoss_20b.sh
│ │ │ ├── run_minicpmo2_6.sh
│ │ │ ├── run_mistral13b_skyworkrm_hhrlhf.sh
│ │ │ ├── run_moonlight16b_math_megatron.sh
│ │ │ ├── run_qwen2-7b.sh
│ │ │ ├── run_qwen2-7b_math.sh
│ │ │ ├── run_qwen2-7b_math_megatron.sh
│ │ │ ├── run_qwen2-7b_seq_balance.sh
│ │ │ ├── run_qwen2-7b_seq_balance_math_megatron.sh
│ │ │ ├── run_qwen2-7b_sgl_megatron.sh
│ │ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora.sh
│ │ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora_from_adapter.sh
│ │ │ ├── run_qwen2_5-7b_math_megatron_diff_tp.sh
│ │ │ ├── run_qwen2_5_32b_grpo_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_discrete_prof_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_e2e_prof_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_npu.sh
│ │ │ ├── run_qwen2_5_vl-7b-megatron.sh
│ │ │ ├── run_qwen2_5_vl-7b-sglang.sh
│ │ │ ├── run_qwen2_5_vl-7b.sh
│ │ │ ├── run_qwen2_5_vl-7b_freeze_vision.sh
│ │ │ ├── run_qwen2_5_vl-7b_lora.sh
│ │ │ ├── run_qwen2_5_vl-7b_seq_balance.sh
│ │ │ ├── run_qwen2_5_vl_32b_npu.sh
│ │ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ │ ├── run_qwen2_5_vl_7b_npu.sh
│ │ │ ├── run_qwen3-235b_megatron_96gb.sh
│ │ │ ├── run_qwen3-32b_npu.sh
│ │ │ ├── run_qwen3-8b.sh
│ │ │ ├── run_qwen3-8b_npu.sh
│ │ │ ├── run_qwen3_8b_grpo_sglang_1k_spmd_npu.sh
│ │ │ ├── run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh
│ │ │ ├── run_qwen3_vl-235b-megatron.sh
│ │ │ ├── run_qwen3_vl-30b-megatron.sh
│ │ │ ├── run_qwen3_vl-8b-megatron.sh
│ │ │ ├── run_qwen3moe-30b_megatron_96gb.sh
│ │ │ └── run_seed_oss_36b.sh
│ │ ├── ppo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ ├── run_deepseek7b_llm_modelscope.sh
│ │ │ ├── run_deepseek7b_llm_pfppo.sh
│ │ │ ├── run_deepseek7b_llm_sandbox_fusion.sh
│ │ │ ├── run_deepseek7b_llm_sp2.sh
│ │ │ ├── run_deepseek_full_hh_rlhf.sh
│ │ │ ├── run_deepseek_math_gsm8k_megatron.sh
│ │ │ ├── run_deepseek_math_gsm8k_megatron_nsys.sh
│ │ │ ├── run_gemma.sh
│ │ │ ├── run_moonlight16b_a3b_gsm8k_megatron.sh
│ │ │ ├── run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh
│ │ │ ├── run_qwen2-7b_math_gsm8k_megatron.sh
│ │ │ ├── run_qwen2-7b_rm.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance_fused_kernels.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance_nsys.sh
│ │ │ ├── run_qwen2-7b_seq_balance.sh
│ │ │ ├── run_qwen2-7b_sglang_seq_balance.sh
│ │ │ ├── run_qwen2.5-32b.sh
│ │ │ └── run_qwen3-8b_npu.sh
│ │ ├── ray/
│ │ │ └── tutorial.ipynb
│ │ ├── reinforce_plus_plus_trainer/
│ │ │ ├── run_qwen2-7b_math_rf.sh
│ │ │ └── run_qwen2-7b_math_rf_baseline.sh
│ │ ├── remax_trainer/
│ │ │ ├── run_qwen2.5-3b_seq_balance.sh
│ │ │ └── run_qwen2.5-7b_seq_balance.sh
│ │ ├── rloo_trainer/
│ │ │ └── run_qwen2-7b.sh
│ │ ├── rollout_importance_sampling/
│ │ │ ├── README.md
│ │ │ └── run_with_rollout_is.sh
│ │ ├── sft/
│ │ │ ├── gsm8k/
│ │ │ │ ├── run_deepseek_6b7.sh
│ │ │ │ ├── run_gemma_2b.sh
│ │ │ │ ├── run_gemma_7b.sh
│ │ │ │ ├── run_qwen3_8b_sft_peft_sp2_npu.sh
│ │ │ │ ├── run_qwen_05_peft.sh
│ │ │ │ ├── run_qwen_05_sp2.sh
│ │ │ │ ├── run_qwen_05_sp2_liger.sh
│ │ │ │ └── run_seed_oss_36b_sft.sh
│ │ │ └── multiturn/
│ │ │ └── run_qwen_05_sp2.sh
│ │ ├── sglang_multiturn/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── geo3k_multiturn_grpo.yaml
│ │ │ │ ├── geo3k_multiturn_megatron_grpo.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo_server.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo_w_interaction.yaml
│ │ │ │ ├── gsm8k_multiturn_megatron_grpo.yaml
│ │ │ │ ├── interaction_config/
│ │ │ │ │ └── gsm8k_interaction_config.yaml
│ │ │ │ ├── retool_multiturn_grpo.yaml
│ │ │ │ ├── search_multiturn_grpo.yaml
│ │ │ │ ├── search_multiturn_grpo_one_step_off.yaml
│ │ │ │ └── tool_config/
│ │ │ │ ├── geo3k_tool_config.yaml
│ │ │ │ ├── gsm8k_tool_config.yaml
│ │ │ │ ├── mcp_server.json
│ │ │ │ ├── mcp_tool_config.yaml
│ │ │ │ ├── sandbox_fusion_tool_config.yaml
│ │ │ │ └── search_tool_config.yaml
│ │ │ ├── geo3k/
│ │ │ │ ├── run_qwen2.5-3b_geo3k_multiturn.sh
│ │ │ │ ├── run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh
│ │ │ │ └── run_qwen2.5-3b_megatron_geo3k_multiturn.sh
│ │ │ ├── run_qwen0.5b_gsm8k_multiturn_curriculum.sh
│ │ │ ├── run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu_server.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_server.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh
│ │ │ ├── run_qwen2.5-3b_megatron_gsm8k_multiturn.sh
│ │ │ ├── run_qwen3-4b_gsm8k_multiturn.sh
│ │ │ ├── run_qwen3_4b_dapo_multiturn.sh
│ │ │ └── search_r1_like/
│ │ │ ├── local_dense_retriever/
│ │ │ │ ├── download.py
│ │ │ │ └── retrieval_server.py
│ │ │ └── run_qwen2.5-3b_instruct_search_multiturn.sh
│ │ ├── skypilot/
│ │ │ ├── README.md
│ │ │ ├── verl-grpo.yaml
│ │ │ ├── verl-multiturn-tools.yaml
│ │ │ └── verl-ppo.yaml
│ │ ├── slurm/
│ │ │ └── ray_on_slurm.slurm
│ │ ├── split_placement/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── ppo_trainer_split.yaml
│ │ │ ├── main_ppo_split.py
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ └── split_monkey_patch.py
│ │ ├── tuning/
│ │ │ ├── 0.5b/
│ │ │ │ └── qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 1.5b/
│ │ │ │ └── qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 14b/
│ │ │ │ ├── qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh
│ │ │ │ └── qwen2_14b_grpo_4_h800_fsdp_vllm.sh
│ │ │ ├── 32b/
│ │ │ │ ├── qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh
│ │ │ │ └── qwen2_32B_grpo_8_h20_megatron_vllm.sh
│ │ │ ├── 3b/
│ │ │ │ └── qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 70b/
│ │ │ │ ├── qwen2-70b_grpo_32_h20_fsdp_vllm.sh
│ │ │ │ ├── qwen2-70b_grpo_32_h800_fsdp_vllm.sh
│ │ │ │ └── qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh
│ │ │ └── 7b/
│ │ │ ├── qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ └── qwen2-7b_grpo_2_h800_fsdp_vllm.sh
│ │ └── tutorial/
│ │ └── agent_loop_get_started/
│ │ ├── agent_loop_tutorial.ipynb
│ │ └── sandbox.py
│ ├── init_ray.sh
│ ├── init_ray_cluster.sh
│ ├── pyproject.toml
│ ├── recipe/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── char_count/
│ │ │ ├── README.md
│ │ │ ├── create_dataset.py
│ │ │ ├── reward_function.py
│ │ │ ├── train_grpo.sh
│ │ │ └── train_sft.sh
│ │ ├── collabllm/
│ │ │ ├── README.md
│ │ │ ├── collabllm_agent_loop.py
│ │ │ ├── collabllm_interation.py
│ │ │ ├── config/
│ │ │ │ ├── agent.yaml
│ │ │ │ └── collabllm_interaction_config.yaml
│ │ │ ├── metrics/
│ │ │ │ ├── accuracy.py
│ │ │ │ ├── bleu_score.py
│ │ │ │ ├── interactivity.py
│ │ │ │ ├── pass_rate.py
│ │ │ │ └── token_amount.py
│ │ │ ├── process_dataset.py
│ │ │ ├── reward_function.py
│ │ │ ├── train_rl_collabllm.sh
│ │ │ ├── train_sft_collabllm.sh
│ │ │ └── utils.py
│ │ ├── dapo/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── dapo_megatron_trainer.yaml
│ │ │ │ └── dapo_trainer.yaml
│ │ │ ├── dapo_ray_trainer.py
│ │ │ ├── main_dapo.py
│ │ │ ├── prepare_dapo_data.sh
│ │ │ ├── run_dapo_early_qwen2.5_32b.sh
│ │ │ ├── run_dapo_qwen2.5_32b.sh
│ │ │ ├── run_dapo_qwen2.5_32b_npu.sh
│ │ │ ├── run_dapo_qwen2.5_32b_rollout_is.sh
│ │ │ ├── run_dapo_qwen2.5_7b_npu.sh
│ │ │ ├── run_dapo_qwen3_14b_base_npu.sh
│ │ │ ├── run_dapo_qwen3_8b_base_npu.sh
│ │ │ ├── run_dapo_qwen3_moe_30b_base_fsdp_npu.sh
│ │ │ ├── run_dapo_qwen3_moe_30b_megatron_npu.sh
│ │ │ ├── run_dapo_wo_ds_qwen2.5_32b.sh
│ │ │ ├── runtime_env.yaml
│ │ │ ├── test_dapo_7b.sh
│ │ │ ├── test_dapo_7b_math.sh
│ │ │ ├── test_dapo_7b_math_lora.sh
│ │ │ ├── test_dapo_7b_math_megatron.sh
│ │ │ ├── test_dapo_dspk_671b_megatron_96gb.sh
│ │ │ ├── test_dapo_glm_air_megatron.sh
│ │ │ ├── test_dapo_qwen3_30b_math.sh
│ │ │ └── test_dapo_qwen3_30b_math_single_node.sh
│ │ ├── deepeyes/
│ │ │ ├── README.md
│ │ │ ├── configs/
│ │ │ │ ├── deepeyes_multiturn_grpo.yaml
│ │ │ │ └── image_zoom_in_tool_config.yaml
│ │ │ ├── deepeyes.py
│ │ │ └── run_deepeyes_grpo.sh
│ │ ├── entropy/
│ │ │ ├── 32b_clip_cov.sh
│ │ │ ├── 32b_kl_cov.sh
│ │ │ ├── 32b_kl_cov_mininbsz.sh
│ │ │ ├── 7b_clip_cov.sh
│ │ │ ├── 7b_kl_cov.sh
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── entropy_trainer.yaml
│ │ │ ├── entropy_ray_trainer.py
│ │ │ ├── main_entropy.py
│ │ │ ├── reward.py
│ │ │ └── reward_score/
│ │ │ ├── __init__.py
│ │ │ └── entropy_math/
│ │ │ ├── __init__.py
│ │ │ ├── grader.py
│ │ │ └── math_normalize.py
│ │ ├── fapo/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── rm_config.yaml
│ │ │ ├── prepare_fapo_data.py
│ │ │ ├── reward_fn_genrm.py
│ │ │ ├── reward_fn_reasoning.py
│ │ │ ├── reward_fn_reasoning_remote.py
│ │ │ ├── run_baseline_32b.sh
│ │ │ ├── run_baseline_7b.sh
│ │ │ ├── run_fapo_32b.sh
│ │ │ ├── run_fapo_32b_remote.sh
│ │ │ ├── run_fapo_7b.sh
│ │ │ ├── run_fapo_7b_remote.sh
│ │ │ ├── run_fapo_genrm_train.sh
│ │ │ └── runtime_env.yaml
│ │ ├── fully_async_policy/
│ │ │ ├── README.md
│ │ │ ├── README_zh.md
│ │ │ ├── agent_loop/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── agent_loop.py
│ │ │ │ └── partial_single_turn_agent_loop.py
│ │ │ ├── config/
│ │ │ │ ├── fully_async_ppo_megatron_trainer.yaml
│ │ │ │ └── fully_async_ppo_trainer.yaml
│ │ │ ├── detach_utils.py
│ │ │ ├── fsdp2_utils.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── fully_async_main.py
│ │ │ ├── fully_async_rollouter.py
│ │ │ ├── fully_async_trainer.py
│ │ │ ├── megatron_worker.py
│ │ │ ├── message_queue.py
│ │ │ ├── param_sync.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── shell/
│ │ │ │ ├── dapo_7b_math_fsdp2_16_16.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_32_32.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_4_4.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_64_64.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_64_64_mis.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_8_8.sh
│ │ │ │ ├── geo3k_qwen25vl_7b_megatron_4_4.sh
│ │ │ │ └── runtime_env.yaml
│ │ │ ├── unittest/
│ │ │ │ └── simple_streaming_demo.py
│ │ │ └── vllm_rollout/
│ │ │ ├── __init__.py
│ │ │ └── vllm_async_server.py
│ │ ├── genrm_remote/
│ │ │ ├── README.md
│ │ │ ├── reward_function.py
│ │ │ └── run_genrm_remote.sh
│ │ ├── gspo/
│ │ │ ├── test_gspo_3b_math.sh
│ │ │ ├── test_gspo_3b_math_slurm.sh
│ │ │ └── test_gspo_qwen30b_a3b_ep.sh
│ │ ├── infigui-g1/
│ │ │ ├── README.md
│ │ │ ├── reward_fn.py
│ │ │ ├── run_3b.sh
│ │ │ └── run_7b.sh
│ │ ├── langgraph_agent/
│ │ │ ├── __init__.py
│ │ │ ├── chat_model.py
│ │ │ ├── example/
│ │ │ │ ├── README.md
│ │ │ │ ├── agent.yaml
│ │ │ │ ├── create_dataset.py
│ │ │ │ ├── math_expression.py
│ │ │ │ ├── run_gpt_oss_20b_bf16.sh
│ │ │ │ └── run_qwen2.5_3b.sh
│ │ │ ├── react_agent_loop.py
│ │ │ └── test_react_agent_loop.py
│ │ ├── minicpmo/
│ │ │ └── rl_dataset.py
│ │ ├── one_step_off_policy/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── one_step_off_ppo_megatron_trainer.yaml
│ │ │ │ └── one_step_off_ppo_trainer.yaml
│ │ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ │ ├── dapo_7b_math_fsdp2_colocate.sh
│ │ │ ├── dapo_7b_math_fsdp2_sglang_4_12.sh
│ │ │ ├── dapo_7b_math_fsdp2_sglang_colocate.sh
│ │ │ ├── dapo_7b_math_megatron_4_12.sh
│ │ │ ├── dapo_7b_math_megatron_colocate.sh
│ │ │ ├── distributed_util.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── grpo_0.6b_gsm8k_fsdp2_2_6.sh
│ │ │ ├── grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh
│ │ │ ├── grpo_3b_gsm8k_fsdp2_2_6.sh
│ │ │ ├── main_ppo.py
│ │ │ ├── megatron_workers.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── sglang_sharding_manager.py
│ │ │ ├── utils.py
│ │ │ └── vllm_sharding_manager.py
│ │ ├── onpolicy_distill/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── onpolicy_distill_trainer.yaml
│ │ │ ├── main_onpolicy_distill.py
│ │ │ ├── onpolicy_distill_trainer.py
│ │ │ └── run_qwen3_distill.sh
│ │ ├── open_math_reasoning/
│ │ │ ├── README.md
│ │ │ ├── compute_score.py
│ │ │ ├── prepare_eval_dataset.py
│ │ │ ├── prepare_nvidia-OpenMathReasoning_sft.py
│ │ │ ├── run_eval.sh
│ │ │ ├── run_generation.sh
│ │ │ └── run_sft_qwen3_8b.sh
│ │ ├── prime/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── prime_trainer.yaml
│ │ │ ├── main_prime.py
│ │ │ ├── prime_core_algos.py
│ │ │ ├── prime_dp_rm.py
│ │ │ ├── prime_fsdp_workers.py
│ │ │ ├── prime_ray_trainer.py
│ │ │ ├── run_prime_qwen.sh
│ │ │ └── run_prime_qwen_code.sh
│ │ ├── r1/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── evaluation.yaml
│ │ │ ├── data_process.py
│ │ │ ├── main_eval.py
│ │ │ ├── reward_score.py
│ │ │ ├── run_r1_distill_qwen.sh
│ │ │ └── tasks/
│ │ │ ├── __init__.py
│ │ │ ├── gpqa.py
│ │ │ ├── livecodebench.py
│ │ │ └── math_reward.py
│ │ ├── retool/
│ │ │ ├── README.md
│ │ │ ├── retool.py
│ │ │ ├── retool_sft_preprocess.py
│ │ │ ├── run_gpt_oss_ppo.sh
│ │ │ ├── run_qwen2-32b_dapo.sh
│ │ │ ├── run_qwen2-32b_ppo.sh
│ │ │ ├── run_qwen2-32b_sft.sh
│ │ │ ├── run_qwen2_7b_dapo.sh
│ │ │ ├── run_qwen2_7b_sft.sh
│ │ │ ├── run_qwen2_7b_sft_npu.sh
│ │ │ └── sandbox_fusion_tool_config.yaml
│ │ ├── spin/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── spin_trainer.yaml
│ │ │ ├── core_algos.py
│ │ │ ├── dp_actor.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── main_spin.py
│ │ │ ├── run_spin.sh
│ │ │ ├── spin_trainer.py
│ │ │ └── utils.py
│ │ ├── sppo/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── sppo_trainer.yaml
│ │ │ ├── config.py
│ │ │ ├── dp_actor.py
│ │ │ ├── main_sppo.py
│ │ │ ├── run_qwen2.5-7b_rm.sh
│ │ │ ├── sppo_ray_trainer.py
│ │ │ └── sppo_worker.py
│ │ └── transfer_queue/
│ │ ├── agent_loop.py
│ │ ├── config/
│ │ │ └── transfer_queue_ppo_trainer.yaml
│ │ ├── main_ppo.py
│ │ ├── ray_trainer.py
│ │ └── run_qwen3-8b_transferqueue_npu.sh
│ ├── requirements-cuda.txt
│ ├── requirements-npu.txt
│ ├── requirements.txt
│ ├── requirements_sglang.txt
│ ├── requirements_transferqueue.txt
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── converter_hf_to_mcore.py
│ │ ├── diagnose.py
│ │ ├── generate_trainer_config.sh
│ │ ├── init_random_model.py
│ │ ├── install_vllm_sglang_mcore.sh
│ │ ├── legacy_model_merger.py
│ │ ├── print_cfg.py
│ │ └── rollout_viewer.py
│ ├── setup.py
│ ├── tests/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── experimental/
│ │ │ ├── agent_loop/
│ │ │ │ ├── agent_utils.py
│ │ │ │ ├── qwen_vl_tool_chat_template.jinja2
│ │ │ │ ├── test_agent_loop_reward.py
│ │ │ │ ├── test_agent_loop_reward_model.py
│ │ │ │ ├── test_basic_agent_loop.py
│ │ │ │ ├── test_gpt_oss_tool_parser.py
│ │ │ │ ├── test_multi_modal.py
│ │ │ │ └── test_standalone_rollout.py
│ │ │ └── reward/
│ │ │ ├── reward_fn.py
│ │ │ ├── test_agent_loop_reward_manager.py
│ │ │ └── test_reward_model.py
│ │ ├── interactions/
│ │ │ ├── __init__.py
│ │ │ ├── test_gsm8k_interaction.py
│ │ │ └── test_interaction_registry.py
│ │ ├── kill_github_tests.sh
│ │ ├── models/
│ │ │ ├── test_engine.py
│ │ │ ├── test_transformer.py
│ │ │ └── test_transformers_ulysses.py
│ │ ├── single_controller/
│ │ │ ├── __init__.py
│ │ │ ├── base/
│ │ │ │ └── test_decorator.py
│ │ │ ├── check_worker_alive/
│ │ │ │ └── main.py
│ │ │ ├── detached_worker/
│ │ │ │ ├── README.md
│ │ │ │ ├── client.py
│ │ │ │ ├── run.sh
│ │ │ │ └── server.py
│ │ │ ├── test_auto_padding_on_cpu.py
│ │ │ ├── test_colocated_workers.py
│ │ │ ├── test_colocated_workers_fused.py
│ │ │ ├── test_data_transfer.py
│ │ │ ├── test_decorator_on_cpu.py
│ │ │ ├── test_device_mesh_register.py
│ │ │ ├── test_driverfunc_to_worker.py
│ │ │ ├── test_fused_workers_on_cpu.py
│ │ │ ├── test_high_level_scheduling_api.py
│ │ │ ├── test_nested_worker.py
│ │ │ ├── test_ray_collectives.py
│ │ │ ├── test_ray_local_envs_on_cpu.py
│ │ │ ├── test_ray_utils_on_cpu.py
│ │ │ ├── test_rvdz.py
│ │ │ ├── test_worker_group_basics.py
│ │ │ └── test_worker_group_torch.py
│ │ ├── special_distributed/
│ │ │ ├── README.md
│ │ │ ├── run_all.sh
│ │ │ ├── test_fsdp_ckpt.py
│ │ │ ├── test_mcore_config_converter.py
│ │ │ └── test_tensor_dict.py
│ │ ├── special_e2e/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── check_custom_rwd_fn.py
│ │ │ ├── check_results.py
│ │ │ ├── envs/
│ │ │ │ ├── __init__.py
│ │ │ │ └── digit_completion/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── task.py
│ │ │ │ └── tokenizer.py
│ │ │ ├── generation/
│ │ │ │ ├── run_gen_qwen05.sh
│ │ │ │ └── run_gen_qwen05_server.sh
│ │ │ ├── ppo_trainer/
│ │ │ │ ├── expert_parallel/
│ │ │ │ │ └── qwen2moe_minimal.json
│ │ │ │ ├── run_function_reward.sh
│ │ │ │ ├── run_model_reward.sh
│ │ │ │ ├── run_single_gpu.sh
│ │ │ │ └── run_single_gpu_with_engine.sh
│ │ │ ├── run_dapo.sh
│ │ │ ├── run_fully_async_policy.sh
│ │ │ ├── run_genrm_remote.sh
│ │ │ ├── run_geo3k_fsdp_sgl_multiturn_w_tool.sh
│ │ │ ├── run_grpo_lora_with_merge.sh
│ │ │ ├── run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh
│ │ │ ├── run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
│ │ │ ├── run_one_step_off_policy.sh
│ │ │ ├── run_ppo_trainer_megatron.sh
│ │ │ ├── run_prime.sh
│ │ │ ├── run_r1_distill_qwen_aime24_eval.sh
│ │ │ ├── run_spin.sh
│ │ │ ├── run_sppo.sh
│ │ │ ├── run_test.sh
│ │ │ └── sft/
│ │ │ ├── compare_sft_engine_results.py
│ │ │ ├── run_sft.sh
│ │ │ ├── run_sft_engine_gsm8k.sh
│ │ │ ├── test_sft_engine_all.sh
│ │ │ └── test_sp_loss_match.py
│ │ ├── special_npu/
│ │ │ ├── run_qwen2_5_05b_dapo.sh
│ │ │ ├── run_qwen2_5_05b_grpo.sh
│ │ │ ├── run_qwen2_5_05b_grpo_mindspeed.sh
│ │ │ ├── run_qwen2_5_05b_sft_peft_sp2.sh
│ │ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ │ └── run_qwen3_06b_ppo.sh
│ │ ├── special_sanity/
│ │ │ ├── check_api_docs.py
│ │ │ ├── check_dataproto_usage.py
│ │ │ ├── check_device_api_usage.py
│ │ │ ├── check_docs_time_info.py
│ │ │ ├── check_docstrings.py
│ │ │ ├── check_license.py
│ │ │ ├── check_pr_description.py
│ │ │ ├── check_pr_title.py
│ │ │ ├── test_config_docs.py
│ │ │ ├── test_import.py
│ │ │ ├── type_coverage_check.py
│ │ │ ├── validate_imported_docs.py
│ │ │ └── validate_structure.py
│ │ ├── special_standalone/
│ │ │ ├── README.md
│ │ │ └── test_memory_buffers.py
│ │ ├── test_base_config_on_cpu.py
│ │ ├── test_protocol_on_cpu.py
│ │ ├── test_protocol_v2_on_cpu.py
│ │ ├── trainer/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── legacy_ppo_megatron_trainer.yaml
│ │ │ │ ├── legacy_ppo_trainer.yaml
│ │ │ │ ├── test_algo_config_on_cpu.py
│ │ │ │ └── test_legacy_config_on_cpu.py
│ │ │ └── ppo/
│ │ │ ├── __init__.py
│ │ │ ├── test_core_algos_on_cpu.py
│ │ │ ├── test_metric_utils_on_cpu.py
│ │ │ ├── test_rollout_is.py
│ │ │ └── test_rollout_is_integration.py
│ │ ├── utils/
│ │ │ ├── _test_module.py
│ │ │ ├── dataset/
│ │ │ │ ├── test_create_rl_sampler_on_cpu.py
│ │ │ │ ├── test_multiturn_sft_dataset_on_cpu.py
│ │ │ │ ├── test_rl_collate_fn_on_cpu.py
│ │ │ │ ├── test_rl_dataset_on_cpu.py
│ │ │ │ └── test_sft_dataset_on_cpu.py
│ │ │ ├── debug/
│ │ │ │ └── test_metrics.py
│ │ │ ├── megatron/
│ │ │ │ └── test_pipeline_parallel.py
│ │ │ ├── reward_score/
│ │ │ │ ├── reward_score/
│ │ │ │ │ └── test_sandbox_fusion_on_cpu.py
│ │ │ │ └── test_sandbox_on_cpu.py
│ │ │ ├── test_activation_offload.py
│ │ │ ├── test_config_on_cpu.py
│ │ │ ├── test_flops_counter.py
│ │ │ ├── test_fs_on_cpu.py
│ │ │ ├── test_groupwise.py
│ │ │ ├── test_import_utils_on_cpu.py
│ │ │ ├── test_linear_cross_entropy.py
│ │ │ ├── test_mlflow_key_sanitization.py
│ │ │ ├── test_model_on_cpu.py
│ │ │ ├── test_nvtx_profile.py
│ │ │ ├── test_rollout_skip_on_cpu.py
│ │ │ ├── test_rollout_trace_on_cpu.py
│ │ │ ├── test_seqlen_balancing.py
│ │ │ ├── test_special_linear_cross_entropy_tp.py
│ │ │ ├── test_special_mstx_profile.py
│ │ │ ├── test_temp_env_on_cpu.py
│ │ │ ├── test_timeout_decorator_cpu.py
│ │ │ └── test_torch_functional.py
│ │ └── workers/
│ │ ├── actor/
│ │ │ └── test_special_dp_actor.py
│ │ ├── config/
│ │ │ ├── test_actor_config_on_cpu.py
│ │ │ ├── test_critic_config_on_cpu.py
│ │ │ ├── test_engine_config_on_cpu.py
│ │ │ └── test_optim_config_on_cpu.py
│ │ ├── critic/
│ │ │ └── test_special_dp_critic.py
│ │ ├── reward_manager/
│ │ │ └── test_registry_on_cpu.py
│ │ ├── rollout/
│ │ │ ├── perf/
│ │ │ │ └── vllm_async_rollout.py
│ │ │ ├── resource/
│ │ │ │ └── tool_configs/
│ │ │ │ ├── mcp_server.json
│ │ │ │ ├── mcp_tool_config
│ │ │ │ ├── sandbox_fusion_tool_config
│ │ │ │ └── search_tool_config
│ │ │ ├── rollout_sglang/
│ │ │ │ └── test_http_server_engine.py
│ │ │ ├── rollout_vllm/
│ │ │ │ ├── run_fsdp_vllm.py
│ │ │ │ ├── test_vllm_model_rope_scaling.py
│ │ │ │ └── test_vllm_spmd.py
│ │ │ ├── test_hf_rollout.py
│ │ │ ├── test_sglang_async_rollout_mcp_tools.py
│ │ │ ├── test_sglang_async_rollout_multimodal_delta.py
│ │ │ ├── test_sglang_async_rollout_search_tools.py
│ │ │ ├── test_sglang_async_rollout_sf_tools.py
│ │ │ ├── test_sglang_async_rollout_w_interaction.py
│ │ │ ├── test_sglang_async_rollout_w_tools.py
│ │ │ ├── test_sglang_async_rollout_w_tools_token_out.py
│ │ │ ├── test_sglang_multi_interaction.py
│ │ │ ├── test_sglang_rollout_sharding_manager.py
│ │ │ ├── test_sglang_spmd.py
│ │ │ └── utils_sglang.py
│ │ ├── test_fsdp_attn_implementation.py
│ │ └── test_fsdp_workers.py
│ └── verl/
│ ├── __init__.py
│ ├── base_config.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ ├── agent_loop/
│ │ │ ├── __init__.py
│ │ │ ├── agent_loop.py
│ │ │ ├── single_turn_agent_loop.py
│ │ │ ├── tool_agent_loop.py
│ │ │ ├── tool_parser.py
│ │ │ └── utils.py
│ │ ├── dataset/
│ │ │ ├── __init__.py
│ │ │ └── sampler.py
│ │ ├── dynamic_dataset/
│ │ │ ├── __init__.py
│ │ │ └── dynamicgen_dataset.py
│ │ └── reward/
│ │ ├── __init__.py
│ │ ├── reward_loop/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── dapo.py
│ │ │ ├── naive.py
│ │ │ └── registry.py
│ │ ├── reward_manager.py
│ │ ├── reward_model.py
│ │ └── router/
│ │ ├── naive_router.py
│ │ └── sglang_router.py
│ ├── interactions/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── gsm8k_interaction.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ └── interaction_registry.py
│ │ └── weather_interaction.py
│ ├── model_merger/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── base_model_merger.py
│ │ ├── fsdp_model_merger.py
│ │ └── megatron_model_merger.py
│ ├── models/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── llama/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── llama_loader.py
│ │ │ │ ├── llama_loader_depracated.py
│ │ │ │ └── llama_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_llama_megatron.py
│ │ ├── mcore/
│ │ │ ├── __init__.py
│ │ │ ├── config_converter.py
│ │ │ ├── loader.py
│ │ │ ├── mbridge.py
│ │ │ ├── model_forward.py
│ │ │ ├── model_forward_1f1b_overlap.py
│ │ │ ├── model_forward_fused.py
│ │ │ ├── model_initializer.py
│ │ │ ├── patch_v012.py
│ │ │ ├── qwen2_5_vl/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── model.py
│ │ │ │ ├── rope_utils.py
│ │ │ │ ├── vision_config.py
│ │ │ │ ├── vision_model.py
│ │ │ │ └── vision_transformer_block.py
│ │ │ ├── readme.md
│ │ │ ├── registry.py
│ │ │ ├── saver.py
│ │ │ ├── util.py
│ │ │ └── weight_converter.py
│ │ ├── qwen2/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── qwen2_loader.py
│ │ │ │ ├── qwen2_loader_depracated.py
│ │ │ │ └── qwen2_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_qwen2_megatron.py
│ │ ├── registry.py
│ │ ├── transformers/
│ │ │ ├── __init__.py
│ │ │ ├── apertus.py
│ │ │ ├── dense_common.py
│ │ │ ├── glm4v.py
│ │ │ ├── kimi_vl.py
│ │ │ ├── llama.py
│ │ │ ├── monkey_patch.py
│ │ │ ├── npu_patch.py
│ │ │ ├── qwen2.py
│ │ │ ├── qwen2_vl.py
│ │ │ └── qwen3_vl.py
│ │ └── weight_loader_registry.py
│ ├── protocol.py
│ ├── py.typed
│ ├── single_controller/
│ │ ├── __init__.py
│ │ ├── base/
│ │ │ ├── __init__.py
│ │ │ ├── decorator.py
│ │ │ ├── worker.py
│ │ │ └── worker_group.py
│ │ └── ray/
│ │ ├── __init__.py
│ │ └── base.py
│ ├── third_party/
│ │ ├── __init__.py
│ │ ├── sglang/
│ │ │ ├── __init__.py
│ │ │ └── parallel_state.py
│ │ ├── torch/
│ │ │ ├── __init__.py
│ │ │ └── distributed/
│ │ │ ├── __init__.py
│ │ │ ├── _state_dict_utils.py
│ │ │ └── checkpoint/
│ │ │ ├── __init__.py
│ │ │ └── state_dict.py
│ │ └── vllm/
│ │ └── __init__.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── base_tool.py
│ │ ├── geo3k_tool.py
│ │ ├── gsm8k_tool.py
│ │ ├── image_zoom_in_tool.py
│ │ ├── mcp_base_tool.py
│ │ ├── mcp_search_tool.py
│ │ ├── sandbox_fusion_tools.py
│ │ ├── schemas.py
│ │ ├── search_tool.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── mcp_clients/
│ │ │ ├── McpClientManager.py
│ │ │ └── utils.py
│ │ ├── search_r1_like_utils.py
│ │ └── tool_registry.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── _generated_ppo_megatron_trainer.yaml
│ │ │ ├── _generated_ppo_trainer.yaml
│ │ │ ├── actor/
│ │ │ │ ├── actor.yaml
│ │ │ │ ├── dp_actor.yaml
│ │ │ │ └── megatron_actor.yaml
│ │ │ ├── algorithm.py
│ │ │ ├── config.py
│ │ │ ├── critic/
│ │ │ │ ├── critic.yaml
│ │ │ │ ├── dp_critic.yaml
│ │ │ │ └── megatron_critic.yaml
│ │ │ ├── data/
│ │ │ │ └── legacy_data.yaml
│ │ │ ├── engine/
│ │ │ │ ├── fsdp.yaml
│ │ │ │ └── megatron.yaml
│ │ │ ├── evaluation.yaml
│ │ │ ├── generation.yaml
│ │ │ ├── model/
│ │ │ │ └── hf_model.yaml
│ │ │ ├── npu_profile/
│ │ │ │ └── npu_profile.yaml
│ │ │ ├── optim/
│ │ │ │ ├── fsdp.yaml
│ │ │ │ └── megatron.yaml
│ │ │ ├── ppo_megatron_trainer.yaml
│ │ │ ├── ppo_trainer.yaml
│ │ │ ├── ref/
│ │ │ │ ├── dp_ref.yaml
│ │ │ │ ├── megatron_ref.yaml
│ │ │ │ └── ref.yaml
│ │ │ ├── reward_model/
│ │ │ │ ├── dp_reward_model.yaml
│ │ │ │ ├── megatron_reward_model.yaml
│ │ │ │ └── reward_model.yaml
│ │ │ ├── rollout/
│ │ │ │ └── rollout.yaml
│ │ │ ├── sft_trainer.yaml
│ │ │ └── sft_trainer_engine.yaml
│ │ ├── constants_ppo.py
│ │ ├── fsdp_sft_trainer.py
│ │ ├── main_eval.py
│ │ ├── main_generation.py
│ │ ├── main_generation_server.py
│ │ ├── main_ppo.py
│ │ ├── ppo/
│ │ │ ├── __init__.py
│ │ │ ├── core_algos.py
│ │ │ ├── metric_utils.py
│ │ │ ├── mismatch_helper.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── reward.py
│ │ │ └── utils.py
│ │ ├── runtime_env.yaml
│ │ └── sft_trainer.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── activation_offload.py
│ │ ├── attention_utils.py
│ │ ├── checkpoint/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_handler.py
│ │ │ ├── checkpoint_manager.py
│ │ │ ├── fsdp_checkpoint_manager.py
│ │ │ └── megatron_checkpoint_manager.py
│ │ ├── config.py
│ │ ├── dataset/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── dataset_utils.py
│ │ │ ├── multiturn_sft_dataset.py
│ │ │ ├── onerec_dataset.py
│ │ │ ├── rl_dataset.py
│ │ │ ├── rm_dataset.py
│ │ │ ├── sft_dataset.py
│ │ │ └── vision_utils.py
│ │ ├── debug/
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ ├── performance.py
│ │ │ └── trajectory_tracker.py
│ │ ├── device.py
│ │ ├── distributed.py
│ │ ├── experimental/
│ │ │ ├── __init__.py
│ │ │ └── torch_functional.py
│ │ ├── flops_counter.py
│ │ ├── fs.py
│ │ ├── fsdp_utils.py
│ │ ├── groupwise.py
│ │ ├── hdfs_io.py
│ │ ├── import_utils.py
│ │ ├── kernel/
│ │ │ ├── __init__.py
│ │ │ ├── kernels.py
│ │ │ └── linear_cross_entropy.py
│ │ ├── logger/
│ │ │ ├── __init__.py
│ │ │ └── aggregate_logger.py
│ │ ├── logging_utils.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── dist_checkpointing.py
│ │ │ ├── memory.py
│ │ │ ├── optimizer.py
│ │ │ ├── pipeline_parallel.py
│ │ │ ├── sequence_parallel.py
│ │ │ └── tensor_parallel.py
│ │ ├── megatron_utils.py
│ │ ├── memory_buffer.py
│ │ ├── memory_utils.py
│ │ ├── metric/
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ ├── model.py
│ │ ├── net_utils.py
│ │ ├── npu_utils.py
│ │ ├── profiler/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── empty_annotations.py
│ │ │ ├── mstx_profile.py
│ │ │ ├── nvtx_profile.py
│ │ │ ├── performance.py
│ │ │ └── profile.py
│ │ ├── py_functional.py
│ │ ├── ray_utils.py
│ │ ├── rendezvous/
│ │ │ ├── __init__.py
│ │ │ └── ray_backend.py
│ │ ├── reward_score/
│ │ │ ├── __init__.py
│ │ │ ├── geo3k.py
│ │ │ ├── gsm8k.py
│ │ │ ├── math_batch.py
│ │ │ ├── math_dapo.py
│ │ │ ├── math_reward.py
│ │ │ ├── math_verify.py
│ │ │ ├── prime_code/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── testing_util.py
│ │ │ │ └── utils.py
│ │ │ ├── prime_math/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── grader.py
│ │ │ │ └── math_normalize.py
│ │ │ ├── sandbox_fusion/
│ │ │ │ ├── __init__.py
│ │ │ │ └── utils.py
│ │ │ └── search_r1_like_qa_em.py
│ │ ├── rollout_skip.py
│ │ ├── rollout_trace.py
│ │ ├── seqlen_balancing.py
│ │ ├── tensordict_utils.py
│ │ ├── tokenizer.py
│ │ ├── torch_dtypes.py
│ │ ├── torch_functional.py
│ │ ├── tracking.py
│ │ ├── transferqueue_utils.py
│ │ ├── transformers_compat.py
│ │ ├── ulysses.py
│ │ └── vllm/
│ │ ├── __init__.py
│ │ ├── patch.py
│ │ └── utils.py
│ ├── version/
│ │ └── version
│ └── workers/
│ ├── __init__.py
│ ├── actor/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── dp_actor.py
│ │ └── megatron_actor.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── actor.py
│ │ ├── critic.py
│ │ ├── engine.py
│ │ ├── model.py
│ │ ├── optimizer.py
│ │ ├── reward_model.py
│ │ └── rollout.py
│ ├── critic/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── dp_critic.py
│ │ └── megatron_critic.py
│ ├── engine/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── fsdp/
│ │ │ ├── __init__.py
│ │ │ ├── transformer_impl.py
│ │ │ └── utils.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── transformer_impl.py
│ │ │ └── utils.py
│ │ ├── mindspeed/
│ │ │ ├── __init__.py
│ │ │ └── transformer_impl.py
│ │ └── utils.py
│ ├── fsdp_workers.py
│ ├── megatron_workers.py
│ ├── reward_manager/
│ │ ├── __init__.py
│ │ ├── abstract.py
│ │ ├── batch.py
│ │ ├── dapo.py
│ │ ├── naive.py
│ │ ├── prime.py
│ │ └── registry.py
│ ├── reward_model/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ └── reward_model.py
│ ├── roles/
│ │ ├── __init__.py
│ │ ├── actor.py
│ │ ├── critic.py
│ │ ├── hybrid_engine.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── losses.py
│ │ └── padding.py
│ ├── rollout/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── hf_rollout.py
│ │ ├── naive/
│ │ │ ├── __init__.py
│ │ │ └── naive_rollout.py
│ │ ├── replica.py
│ │ ├── schemas.py
│ │ ├── sglang_rollout/
│ │ │ ├── __init__.py
│ │ │ ├── async_sglang_server.py
│ │ │ ├── http_server_engine.py
│ │ │ ├── sglang_rollout.py
│ │ │ └── utils.py
│ │ ├── tokenizer.py
│ │ ├── utils.py
│ │ └── vllm_rollout/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── vllm_async_server.py
│ │ └── vllm_rollout_spmd.py
│ └── sharding_manager/
│ ├── __init__.py
│ ├── base.py
│ ├── fsdp_sglang.py
│ ├── fsdp_ulysses.py
│ ├── fsdp_vllm.py
│ ├── megatron_sglang.py
│ └── megatron_vllm.py
└── verl_rl/
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── README_ORIGINAL.md
├── deploy_env.sh
├── docker/
│ ├── Apptainerfile.rocm
│ ├── Dockerfile.extention.awsefa
│ ├── Dockerfile.ngc.vllm
│ ├── Dockerfile.ngc.vllm0.8
│ ├── Dockerfile.ngc.vllm0.8.sagemaker
│ ├── Dockerfile.rocm
│ ├── Dockerfile.rocm_verl-0.3.0.post1
│ ├── Dockerfile.rocm_verl-0.4.1
│ ├── Dockerfile.sglang
│ ├── Dockerfile.vemlp.vllm.te
│ ├── Dockerfile.vllm.sglang.megatron.deepseek
│ ├── README.md
│ ├── verl0.4-cu124-torch2.6-fa2.7.4/
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.12
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.12.deepep
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.13.preview
│ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ ├── Dockerfile.app.vllm.mcore0.12.deepep
│ │ ├── Dockerfile.app.vllm.mcore0.13.preview
│ │ ├── Dockerfile.base
│ │ └── README.md
│ ├── verl0.5-cu126-torch2.7-fa2.7.4/
│ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ ├── Dockerfile.base.torch2.7.0
│ │ ├── Dockerfile.base.torch2.7.1
│ │ └── README.md
│ ├── verl0.5-cu126-torch2.7.1-fa2.8.0/
│ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ ├── Dockerfile.app.sglang.mcore0.13.preview
│ │ ├── Dockerfile.base
│ │ └── README.md
│ └── verl0.5-preview-cu128-torch2.7.1-fa2.8.0/
│ ├── Dockerfile.app.sglang.megatron
│ ├── Dockerfile.base
│ └── README.md
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── README_vllm0.7.md
│ ├── README_vllm0.8.md
│ ├── _static/
│ │ └── js/
│ │ └── runllm-widget.js
│ ├── advance/
│ │ ├── agent_loop.rst
│ │ ├── checkpoint.rst
│ │ ├── dpo_extension.rst
│ │ ├── fsdp_extension.rst
│ │ ├── megatron_extension.rst
│ │ ├── one_step_off.md
│ │ ├── placement.rst
│ │ ├── ppo_lora.rst
│ │ ├── rollout_trace.rst
│ │ └── rope.rst
│ ├── algo/
│ │ ├── baseline.md
│ │ ├── dapo.md
│ │ ├── entropy.md
│ │ ├── gpg.md
│ │ ├── grpo.md
│ │ ├── opo.md
│ │ ├── ppo.md
│ │ ├── spin.md
│ │ └── sppo.md
│ ├── amd_tutorial/
│ │ ├── amd_build_dockerfile_page.rst
│ │ └── amd_vllm_page.rst
│ ├── api/
│ │ ├── data.rst
│ │ ├── single_controller.rst
│ │ ├── trainer.rst
│ │ └── utils.rst
│ ├── ascend_tutorial/
│ │ ├── ascend_profiling.rst
│ │ ├── ascend_profiling_en.rst
│ │ └── ascend_quick_start.rst
│ ├── conf.py
│ ├── examples/
│ │ ├── config.rst
│ │ ├── gsm8k_example.rst
│ │ ├── multi_modal_example.rst
│ │ ├── ppo_code_architecture.rst
│ │ └── sandbox_fusion_example.rst
│ ├── faq/
│ │ └── faq.rst
│ ├── hybrid_flow.rst
│ ├── index.rst
│ ├── perf/
│ │ ├── device_tuning.rst
│ │ ├── dpsk.md
│ │ ├── nsight_profiling.md
│ │ └── perf_tuning.rst
│ ├── preparation/
│ │ ├── prepare_data.rst
│ │ └── reward_function.rst
│ ├── requirements-docs.txt
│ ├── sglang_multiturn/
│ │ ├── interaction_system.rst
│ │ ├── multiturn.rst
│ │ ├── sandbox_fusion.rst
│ │ └── search_tool_example.rst
│ ├── single_controller.rst
│ ├── start/
│ │ ├── agentic_rl.rst
│ │ ├── install.rst
│ │ ├── more_resources.rst
│ │ ├── multinode.rst
│ │ ├── quickstart.rst
│ │ └── ray_debug_tutorial.rst
│ └── workers/
│ ├── fsdp_workers.rst
│ ├── megatron_workers.rst
│ ├── ray_trainer.rst
│ └── sglang_worker.rst
├── examples/
│ ├── data_preprocess/
│ │ ├── aime2024_multiturn_w_tool.py
│ │ ├── dapo_multiturn_w_tool.py
│ │ ├── full_hh_rlhf.py
│ │ ├── geo3k.py
│ │ ├── geo3k_multiturn_w_tool.py
│ │ ├── gsm8k.py
│ │ ├── gsm8k_multiturn_w_interaction.py
│ │ ├── gsm8k_multiturn_w_tool.py
│ │ ├── gsm8k_tool_agent_loop.py
│ │ ├── hellaswag.py
│ │ ├── math_dataset.py
│ │ ├── multiturn.py
│ │ └── preprocess_search_r1_dataset.py
│ ├── generation/
│ │ ├── run_deepseek7b_mutli_node.sh
│ │ └── run_deepseek_v2_lite_math.sh
│ ├── gpg_trainer/
│ │ ├── gpg.md
│ │ ├── run_qwen2-7b_math.sh
│ │ └── run_qwen2-7b_math_megatron.sh
│ ├── grpo_trainer/
│ │ ├── README.md
│ │ ├── run_deepseek671b_math_megatron.sh
│ │ ├── run_deepseek7b_llm.sh
│ │ ├── run_deepseek7b_llm_math.sh
│ │ ├── run_deepseek7b_llm_math_megatron.sh
│ │ ├── run_deepseek7b_llm_seq_balance.sh
│ │ ├── run_minicpmo2_6.sh
│ │ ├── run_moonlight16b_math_megatron.sh
│ │ ├── run_qwen2-7b.sh
│ │ ├── run_qwen2-7b_math.sh
│ │ ├── run_qwen2-7b_math_megatron.sh
│ │ ├── run_qwen2-7b_seq_balance.sh
│ │ ├── run_qwen2-7b_seq_balance_math_megatron.sh
│ │ ├── run_qwen2-7b_sgl_megatron.sh
│ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora.sh
│ │ ├── run_qwen2_5-7b_math_megatron_diff_tp.sh
│ │ ├── run_qwen2_5_32b_grpo_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_discrete_prof_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_e2e_prof_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_npu.sh
│ │ ├── run_qwen2_5_vl-7b-megatron.sh
│ │ ├── run_qwen2_5_vl-7b.sh
│ │ ├── run_qwen2_5_vl-7b_lora.sh
│ │ ├── run_qwen2_5_vl-7b_seq_balance.sh
│ │ ├── run_qwen2_5_vl_32b_npu.sh
│ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ ├── run_qwen2_5_vl_7b_npu.sh
│ │ ├── run_qwen3-236b_megatron.sh
│ │ ├── run_qwen3-8b.sh
│ │ └── run_qwen3moe-30b_megatron.sh
│ ├── ppo_trainer/
│ │ ├── README.md
│ │ ├── run_deepseek7b_llm.sh
│ │ ├── run_deepseek7b_llm_modelscope.sh
│ │ ├── run_deepseek7b_llm_pfppo.sh
│ │ ├── run_deepseek7b_llm_sandbox_fusion.sh
│ │ ├── run_deepseek7b_llm_sp2.sh
│ │ ├── run_deepseek_full_hh_rlhf.sh
│ │ ├── run_deepseek_math_gsm8k_megatron.sh
│ │ ├── run_deepseek_math_gsm8k_megatron_nsys.sh
│ │ ├── run_gemma.sh
│ │ ├── run_moonlight16b_a3b_gsm8k_megatron.sh
│ │ ├── run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh
│ │ ├── run_qwen2-7b_math_gsm8k_megatron.sh
│ │ ├── run_qwen2-7b_rm.sh
│ │ ├── run_qwen2-7b_rm_seq_balance.sh
│ │ ├── run_qwen2-7b_rm_seq_balance_fused_kernels.sh
│ │ ├── run_qwen2-7b_rm_seq_balance_nsys.sh
│ │ ├── run_qwen2-7b_seq_balance.sh
│ │ ├── run_qwen2-7b_sglang_seq_balance.sh
│ │ └── run_qwen2.5-32b.sh
│ ├── ray/
│ │ └── tutorial.ipynb
│ ├── reinforce_plus_plus_trainer/
│ │ ├── run_qwen2-7b_math_rf.sh
│ │ └── run_qwen2-7b_math_rf_baseline.sh
│ ├── remax_trainer/
│ │ ├── run_qwen2.5-3b_seq_balance.sh
│ │ └── run_qwen2.5-7b_seq_balance.sh
│ ├── rloo_trainer/
│ │ └── run_qwen2-7b.sh
│ ├── sft/
│ │ ├── gsm8k/
│ │ │ ├── run_deepseek_6b7.sh
│ │ │ ├── run_gemma_2b.sh
│ │ │ ├── run_gemma_7b.sh
│ │ │ ├── run_qwen2_5_05b_sft_peft_sp2_npu.sh
│ │ │ ├── run_qwen_05_peft.sh
│ │ │ ├── run_qwen_05_sp2.sh
│ │ │ └── run_qwen_05_sp2_liger.sh
│ │ └── multiturn/
│ │ └── run_qwen_05_sp2.sh
│ ├── sglang_multiturn/
│ │ ├── README.md
│ │ ├── config/
│ │ │ ├── geo3k_multiturn_grpo.yaml
│ │ │ ├── geo3k_multiturn_megatron_grpo.yaml
│ │ │ ├── gsm8k_multiturn_grpo.yaml
│ │ │ ├── gsm8k_multiturn_grpo_w_interaction.yaml
│ │ │ ├── gsm8k_multiturn_megatron_grpo.yaml
│ │ │ ├── interaction_config/
│ │ │ │ └── gsm8k_interaction_config.yaml
│ │ │ ├── retool_multiturn_grpo.yaml
│ │ │ ├── search_multiturn_grpo.yaml
│ │ │ └── tool_config/
│ │ │ ├── geo3k_tool_config.yaml
│ │ │ ├── gsm8k_tool_config.yaml
│ │ │ ├── mcp_server.json
│ │ │ ├── mcp_tool_config.yaml
│ │ │ ├── sandbox_fusion_tool_config.yaml
│ │ │ └── search_tool_config.yaml
│ │ ├── geo3k/
│ │ │ ├── run_qwen2.5-3b_geo3k_multiturn.sh
│ │ │ ├── run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh
│ │ │ └── run_qwen2.5-3b_megatron_geo3k_multiturn.sh
│ │ ├── run_qwen0.5b_gsm8k_multiturn_curriculum.sh
│ │ ├── run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh
│ │ ├── run_qwen2.5-3b_gsm8k_multiturn.sh
│ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
│ │ ├── run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh
│ │ ├── run_qwen2.5-3b_megatron_gsm8k_multiturn.sh
│ │ ├── run_qwen3-4b_gsm8k_multiturn.sh
│ │ └── search_r1_like/
│ │ ├── local_dense_retriever/
│ │ │ ├── download.py
│ │ │ └── retrieval_server.py
│ │ └── run_qwen2.5-3b_instruct_search_multiturn.sh
│ ├── slurm/
│ │ └── ray_on_slurm.slurm
│ ├── split_placement/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── ppo_trainer_split.yaml
│ │ ├── main_ppo_split.py
│ │ ├── run_deepseek7b_llm.sh
│ │ └── split_monkey_patch.py
│ └── tuning/
│ ├── 0.5b/
│ │ └── qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 1.5b/
│ │ └── qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 14b/
│ │ ├── qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh
│ │ └── qwen2_14b_grpo_4_h800_fsdp_vllm.sh
│ ├── 32b/
│ │ ├── qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh
│ │ └── qwen2_32B_grpo_8_h20_megatron_vllm.sh
│ ├── 3b/
│ │ └── qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 70b/
│ │ ├── qwen2-70b_grpo_32_h20_fsdp_vllm.sh
│ │ ├── qwen2-70b_grpo_32_h800_fsdp_vllm.sh
│ │ └── qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh
│ └── 7b/
│ ├── qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh
│ └── qwen2-7b_grpo_2_h800_fsdp_vllm.sh
├── init_ray.sh
├── init_ray_cluster.sh
├── pyproject.toml
├── recipe/
│ ├── README.md
│ ├── char_count/
│ │ ├── README.md
│ │ ├── create_dataset.py
│ │ ├── reward_function.py
│ │ ├── train_grpo.sh
│ │ └── train_sft.sh
│ ├── dapo/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── dapo_trainer.yaml
│ │ ├── dapo_ray_trainer.py
│ │ ├── main_dapo.py
│ │ ├── prepare_dapo_data.sh
│ │ ├── run_dapo_early_qwen2.5_32b.sh
│ │ ├── run_dapo_qwen2.5_32b.sh
│ │ ├── run_dapo_wo_ds_qwen2.5_32b.sh
│ │ ├── runtime_env.yaml
│ │ ├── test_dapo_7b.sh
│ │ ├── test_dapo_7b_math.sh
│ │ ├── test_dapo_7b_math_lora.sh
│ │ ├── test_dapo_7b_math_megatron.sh
│ │ ├── test_dapo_dspk_671b_megatron.sh
│ │ ├── test_dapo_qwen3_30b_math.sh
│ │ └── test_dapo_qwen3_30b_math_single_node.sh
│ ├── entropy/
│ │ ├── 32b_clip_cov.sh
│ │ ├── 32b_kl_cov.sh
│ │ ├── 32b_kl_cov_mininbsz.sh
│ │ ├── 7b_clip_cov.sh
│ │ ├── 7b_kl_cov.sh
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── entropy_trainer.yaml
│ │ ├── entropy_ray_trainer.py
│ │ ├── main_entropy.py
│ │ ├── reward.py
│ │ └── reward_score/
│ │ ├── __init__.py
│ │ └── entropy_math/
│ │ ├── __init__.py
│ │ ├── grader.py
│ │ └── math_normalize.py
│ ├── genrm_remote/
│ │ ├── README.md
│ │ ├── reward_function.py
│ │ └── run_genrm_remote.sh
│ ├── langgraph_agent/
│ │ ├── __init__.py
│ │ ├── chat_model.py
│ │ ├── example/
│ │ │ ├── README.md
│ │ │ ├── agent.yaml
│ │ │ ├── create_dataset.py
│ │ │ ├── math_expression.py
│ │ │ └── run_qwen2.5_3b.sh
│ │ ├── react_agent_loop.py
│ │ └── test_react_agent_loop.py
│ ├── minicpmo/
│ │ └── rl_dataset.py
│ ├── one_step_off_policy/
│ │ ├── README.md
│ │ ├── config/
│ │ │ ├── one_step_off_ppo_megatron_trainer.yaml
│ │ │ └── one_step_off_ppo_trainer.yaml
│ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ ├── dapo_7b_math_fsdp2_colocate.sh
│ │ ├── dapo_7b_math_megatron_4_12.sh
│ │ ├── dapo_7b_math_megatron_colocate.sh
│ │ ├── fsdp_workers.py
│ │ ├── grpo_0.6b_gsm8k_fsdp2_2_6.sh
│ │ ├── grpo_3b_gsm8k_fsdp2_2_6.sh
│ │ ├── main_ppo.py
│ │ ├── megatron_workers.py
│ │ ├── ray_trainer.py
│ │ └── vllm_sharding_manager.py
│ ├── onerec/
│ │ ├── main_onerec_ppo.py
│ │ ├── onerec_fsdp_workers.py
│ │ ├── onerec_ray_trainer.py
│ │ ├── onerec_recipe.py
│ │ ├── onerec_vllm_rollout.py
│ │ └── run_grpo.sh
│ ├── prime/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ └── prime_trainer.yaml
│ │ ├── main_prime.py
│ │ ├── prime_core_algos.py
│ │ ├── prime_dp_rm.py
│ │ ├── prime_fsdp_workers.py
│ │ ├── prime_ray_trainer.py
│ │ ├── run_prime_qwen.sh
│ │ └── run_prime_qwen_code.sh
│ ├── r1/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ └── evaluation.yaml
│ │ ├── data_process.py
│ │ ├── main_eval.py
│ │ ├── reward_score.py
│ │ ├── run_r1_distill_qwen.sh
│ │ └── tasks/
│ │ ├── __init__.py
│ │ ├── gpqa.py
│ │ ├── livecodebench.py
│ │ └── math.py
│ ├── retool/
│ │ ├── retool.py
│ │ ├── retool_multi_turn_sft_preprocess.py
│ │ ├── retool_sft_preprocess.py
│ │ ├── run_qwen2-32b_sft.sh
│ │ ├── run_qwen2.5_32b_sp8.sh
│ │ ├── run_qwen2.5_7b_sp4.sh
│ │ ├── run_qwen3_4b_sp4.sh
│ │ └── sandbox_fusion_tool_config.yaml
│ ├── spin/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── spin_trainer.yaml
│ │ ├── core_algos.py
│ │ ├── dp_actor.py
│ │ ├── fsdp_workers.py
│ │ ├── main_spin.py
│ │ ├── run_spin.sh
│ │ └── spin_trainer.py
│ └── sppo/
│ ├── README.md
│ ├── __init__.py
│ ├── config/
│ │ └── sppo_trainer.yaml
│ ├── dp_actor.py
│ ├── main_sppo.py
│ ├── run_qwen2.5-7b_rm.sh
│ ├── sppo_ray_trainer.py
│ └── sppo_worker.py
├── requirements-npu.txt
├── requirements.txt
├── requirements_sglang.txt
├── scripts/
│ ├── __init__.py
│ ├── converter_hf_to_mcore.py
│ ├── diagnose.py
│ ├── generate_trainer_config.sh
│ ├── init_random_model.py
│ ├── install_vllm_sglang_mcore.sh
│ ├── legacy_model_merger.py
│ ├── print_cfg.py
│ └── rollout_viewer.py
├── setup.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── experimental/
│ │ └── agent_loop/
│ │ ├── agent_utils.py
│ │ └── test_basic_agent_loop.py
│ ├── interactions/
│ │ ├── __init__.py
│ │ ├── test_gsm8k_interaction.py
│ │ └── test_interaction_registry.py
│ ├── kill_github_tests.sh
│ ├── models/
│ │ ├── test_transformer.py
│ │ └── test_transformers_ulysses.py
│ ├── single_controller/
│ │ ├── __init__.py
│ │ ├── base/
│ │ │ └── test_decorator.py
│ │ ├── check_worker_alive/
│ │ │ └── main.py
│ │ ├── detached_worker/
│ │ │ ├── README.md
│ │ │ ├── client.py
│ │ │ ├── run.sh
│ │ │ └── server.py
│ │ ├── test_auto_padding_on_cpu.py
│ │ ├── test_colocated_workers.py
│ │ ├── test_colocated_workers_fused.py
│ │ ├── test_data_transfer.py
│ │ ├── test_decorator_on_cpu.py
│ │ ├── test_driverfunc_to_worker.py
│ │ ├── test_fused_workers_on_cpu.py
│ │ ├── test_high_level_scheduling_api.py
│ │ ├── test_ray_collectives.py
│ │ ├── test_ray_local_envs_on_cpu.py
│ │ ├── test_ray_utils_on_cpu.py
│ │ ├── test_rvdz.py
│ │ ├── test_worker_group_basics.py
│ │ └── test_worker_group_torch.py
│ ├── special_distributed/
│ │ ├── README.md
│ │ ├── run_all.sh
│ │ ├── test_fsdp_ckpt.py
│ │ └── test_tensor_dict.py
│ ├── special_e2e/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── check_custom_rwd_fn.py
│ │ ├── check_results.py
│ │ ├── envs/
│ │ │ ├── __init__.py
│ │ │ └── digit_completion/
│ │ │ ├── __init__.py
│ │ │ ├── task.py
│ │ │ └── tokenizer.py
│ │ ├── generation/
│ │ │ └── run_gen_qwen05.sh
│ │ ├── ppo_trainer/
│ │ │ ├── expert_parallel/
│ │ │ │ └── qwen2moe_minimal.json
│ │ │ ├── run_function_reward.sh
│ │ │ ├── run_model_reward.sh
│ │ │ ├── run_single_gpu.sh
│ │ │ └── run_single_gpu_with_engine.sh
│ │ ├── run_dapo.sh
│ │ ├── run_genrm_remote.sh
│ │ ├── run_geo3k_fsdp_sgl_multiturn_w_tool.sh
│ │ ├── run_grpo_lora_with_merge.sh
│ │ ├── run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh
│ │ ├── run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
│ │ ├── run_one_step_off_policy.sh
│ │ ├── run_ppo_trainer_megatron.sh
│ │ ├── run_prime.sh
│ │ ├── run_r1_distill_qwen_aime24_eval.sh
│ │ ├── run_spin.sh
│ │ ├── run_sppo.sh
│ │ ├── run_test.sh
│ │ └── sft/
│ │ ├── run_sft.sh
│ │ └── test_sp_loss_match.py
│ ├── special_npu/
│ │ ├── run_qwen2_5_05b_dapo.sh
│ │ ├── run_qwen2_5_05b_grpo.sh
│ │ ├── run_qwen2_5_05b_sft_peft_sp2.sh
│ │ └── run_qwen2_5_vl_3b_npu.sh
│ ├── special_sanity/
│ │ ├── check_api_docs.py
│ │ ├── check_device_api_usage.py
│ │ ├── check_docs_time_info.py
│ │ ├── check_docstrings.py
│ │ ├── check_license.py
│ │ ├── check_pr_description.py
│ │ ├── check_pr_title.py
│ │ ├── test_config_docs.py
│ │ ├── test_import.py
│ │ ├── type_coverage_check.py
│ │ ├── validate_imported_docs.py
│ │ └── validate_structure.py
│ ├── special_standalone/
│ │ ├── README.md
│ │ └── test_memory_buffers.py
│ ├── test_base_config_on_cpu.py
│ ├── test_protocol_on_cpu.py
│ ├── tools/
│ │ └── test_base_tool_on_cpu.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── legacy_ppo_megatron_trainer.yaml
│ │ │ ├── legacy_ppo_trainer.yaml
│ │ │ ├── test_algo_config_on_cpu.py
│ │ │ ├── test_critic_config_on_cpu.py
│ │ │ └── test_legacy_config_on_cpu.py
│ │ └── ppo/
│ │ ├── __init__.py
│ │ ├── test_core_algos_on_cpu.py
│ │ └── test_metric_utils_on_cpu.py
│ ├── utils/
│ │ ├── _test_module.py
│ │ ├── dataset/
│ │ │ ├── test_create_rl_sampler_on_cpu.py
│ │ │ ├── test_multiturn_sft_dataset_on_cpu.py
│ │ │ ├── test_rl_dataset_on_cpu.py
│ │ │ └── test_sft_dataset_on_cpu.py
│ │ ├── megatron/
│ │ │ └── test_pipeline_parallel.py
│ │ ├── reward_score/
│ │ │ ├── reward_score/
│ │ │ │ └── test_sandbox_fusion_on_cpu.py
│ │ │ └── test_sandbox_on_cpu.py
│ │ ├── test_activation_offload.py
│ │ ├── test_config_on_cpu.py
│ │ ├── test_flops_counter.py
│ │ ├── test_fs_on_cpu.py
│ │ ├── test_import_utils_on_cpu.py
│ │ ├── test_linear_cross_entropy.py
│ │ ├── test_linear_cross_entropy_tp.py
│ │ ├── test_model_on_cpu.py
│ │ ├── test_nvtx_profile.py
│ │ ├── test_rollout_trace_on_cpu.py
│ │ ├── test_seqlen_balancing.py
│ │ ├── test_temp_env_on_cpu.py
│ │ ├── test_timeout_decorator_cpu.py
│ │ └── test_torch_functional.py
│ └── workers/
│ ├── reward_manager/
│ │ └── test_registry_on_cpu.py
│ └── rollout/
│ ├── async_rollout_utils.py
│ ├── perf/
│ │ └── vllm_async_rollout.py
│ ├── resource/
│ │ └── tool_configs/
│ │ ├── mcp_server.json
│ │ ├── mcp_tool_config
│ │ ├── sandbox_fusion_tool_config
│ │ └── search_tool_config
│ ├── rollout_vllm/
│ │ ├── run_fsdp_vllm.py
│ │ ├── test_vllm_chat_scheduler.py
│ │ ├── test_vllm_model_rope_scaling.py
│ │ └── test_vllm_spmd.py
│ ├── test_async_sglang_server_on_cpu.py
│ ├── test_custom_completion_callback.py
│ ├── test_hf_rollout.py
│ ├── test_sglang_async_rollout_mcp_tools.py
│ ├── test_sglang_async_rollout_multimodal_delta.py
│ ├── test_sglang_async_rollout_search_tools.py
│ ├── test_sglang_async_rollout_sf_tools.py
│ ├── test_sglang_async_rollout_w_interaction.py
│ ├── test_sglang_async_rollout_w_tools.py
│ ├── test_sglang_multi_interaction.py
│ ├── test_sglang_rollout_sharding_manager.py
│ ├── test_sglang_spmd.py
│ └── utils_sglang.py
└── verl/
├── __init__.py
├── base_config.py
├── experimental/
│ ├── __init__.py
│ ├── agent_loop/
│ │ ├── __init__.py
│ │ ├── agent_loop.py
│ │ ├── single_turn_agent_loop.py
│ │ ├── tool_agent_loop.py
│ │ └── tool_parser.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ └── sampler.py
│ └── dynamic_dataset/
│ ├── __init__.py
│ └── dynamicgen_dataset.py
├── interactions/
│ ├── __init__.py
│ ├── base.py
│ ├── gsm8k_interaction.py
│ └── utils/
│ ├── __init__.py
│ └── interaction_registry.py
├── model_merger/
│ ├── __init__.py
│ ├── __main__.py
│ ├── base_model_merger.py
│ ├── fsdp_model_merger.py
│ └── megatron_model_merger.py
├── models/
│ ├── README.md
│ ├── __init__.py
│ ├── llama/
│ │ ├── __init__.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ ├── checkpoint_utils/
│ │ │ ├── __init__.py
│ │ │ ├── llama_loader.py
│ │ │ ├── llama_loader_depracated.py
│ │ │ └── llama_saver.py
│ │ ├── layers/
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ └── modeling_llama_megatron.py
│ ├── mcore/
│ │ ├── __init__.py
│ │ ├── config_converter.py
│ │ ├── loader.py
│ │ ├── mbridge.py
│ │ ├── model_forward.py
│ │ ├── model_forward_fused.py
│ │ ├── model_initializer.py
│ │ ├── patch_v012.py
│ │ ├── qwen2_5_vl/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── model.py
│ │ │ ├── rope_utils.py
│ │ │ ├── vision_config.py
│ │ │ ├── vision_model.py
│ │ │ └── vision_transformer_block.py
│ │ ├── readme.md
│ │ ├── registry.py
│ │ ├── saver.py
│ │ ├── util.py
│ │ └── weight_converter.py
│ ├── qwen2/
│ │ ├── __init__.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ ├── checkpoint_utils/
│ │ │ ├── __init__.py
│ │ │ ├── qwen2_loader.py
│ │ │ ├── qwen2_loader_depracated.py
│ │ │ └── qwen2_saver.py
│ │ ├── layers/
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ └── modeling_qwen2_megatron.py
│ ├── registry.py
│ ├── transformers/
│ │ ├── __init__.py
│ │ ├── dense_common.py
│ │ ├── kimi_vl.py
│ │ ├── llama.py
│ │ ├── monkey_patch.py
│ │ ├── npu_patch.py
│ │ ├── qwen2.py
│ │ ├── qwen2_5_vl.py
│ │ └── qwen2_vl.py
│ └── weight_loader_registry.py
├── protocol.py
├── py.typed
├── single_controller/
│ ├── __init__.py
│ ├── base/
│ │ ├── __init__.py
│ │ ├── decorator.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── worker.py
│ │ │ └── worker_group.py
│ │ ├── register_center/
│ │ │ ├── __init__.py
│ │ │ └── ray.py
│ │ ├── worker.py
│ │ └── worker_group.py
│ └── ray/
│ ├── __init__.py
│ ├── base.py
│ └── megatron.py
├── third_party/
│ ├── __init__.py
│ ├── sglang/
│ │ ├── __init__.py
│ │ └── parallel_state.py
│ ├── torch/
│ │ ├── __init__.py
│ │ └── distributed/
│ │ ├── __init__.py
│ │ ├── _state_dict_utils.py
│ │ └── checkpoint/
│ │ ├── __init__.py
│ │ └── state_dict.py
│ └── vllm/
│ └── __init__.py
├── tools/
│ ├── __init__.py
│ ├── base_tool.py
│ ├── geo3k_tool.py
│ ├── gsm8k_tool.py
│ ├── mcp_base_tool.py
│ ├── mcp_search_tool.py
│ ├── sandbox_fusion_tools.py
│ ├── schemas.py
│ ├── search_tool.py
│ └── utils/
│ ├── __init__.py
│ ├── mcp_clients/
│ │ ├── McpClientManager.py
│ │ └── utils.py
│ ├── search_r1_like_utils.py
│ └── tool_registry.py
├── trainer/
│ ├── __init__.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── _generated_ppo_megatron_trainer.yaml
│ │ ├── _generated_ppo_trainer.yaml
│ │ ├── actor/
│ │ │ ├── actor.yaml
│ │ │ ├── dp_actor.yaml
│ │ │ └── megatron_actor.yaml
│ │ ├── algorithm.py
│ │ ├── config.py
│ │ ├── critic/
│ │ │ ├── critic.yaml
│ │ │ ├── dp_critic.yaml
│ │ │ └── megatron_critic.yaml
│ │ ├── data/
│ │ │ └── legacy_data.yaml
│ │ ├── evaluation.yaml
│ │ ├── generation.yaml
│ │ ├── npu_profile/
│ │ │ └── npu_profile.yaml
│ │ ├── ppo_megatron_trainer.yaml
│ │ ├── ppo_trainer.yaml
│ │ ├── ref/
│ │ │ ├── dp_ref.yaml
│ │ │ ├── megatron_ref.yaml
│ │ │ └── ref.yaml
│ │ ├── reward_model/
│ │ │ ├── dp_reward_model.yaml
│ │ │ ├── megatron_reward_model.yaml
│ │ │ └── reward_model.yaml
│ │ ├── rollout/
│ │ │ └── rollout.yaml
│ │ └── sft_trainer.yaml
│ ├── constants_ppo.py
│ ├── fsdp_sft_trainer.py
│ ├── main_eval.py
│ ├── main_generation.py
│ ├── main_ppo.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── core_algos.py
│ │ ├── metric_utils.py
│ │ ├── ray_trainer.py
│ │ └── reward.py
│ └── runtime_env.yaml
├── utils/
│ ├── __init__.py
│ ├── activation_offload.py
│ ├── checkpoint/
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ ├── fsdp_checkpoint_manager.py
│ │ └── megatron_checkpoint_manager.py
│ ├── config.py
│ ├── dataset/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── multiturn_sft_dataset.py
│ │ ├── rl_dataset.py
│ │ ├── rm_dataset.py
│ │ ├── sft_dataset.py
│ │ └── vision_utils.py
│ ├── debug/
│ │ ├── __init__.py
│ │ ├── performance.py
│ │ └── trajectory_tracker.py
│ ├── device.py
│ ├── distributed.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ └── torch_functional.py
│ ├── flops_counter.py
│ ├── fs.py
│ ├── fsdp_utils.py
│ ├── hdfs_io.py
│ ├── import_utils.py
│ ├── kernel/
│ │ ├── __init__.py
│ │ ├── kernels.py
│ │ └── linear_cross_entropy.py
│ ├── logger/
│ │ ├── __init__.py
│ │ └── aggregate_logger.py
│ ├── logging_utils.py
│ ├── megatron/
│ │ ├── __init__.py
│ │ ├── dist_checkpointing.py
│ │ ├── memory.py
│ │ ├── optimizer.py
│ │ ├── pipeline_parallel.py
│ │ ├── sequence_parallel.py
│ │ └── tensor_parallel.py
│ ├── megatron_utils.py
│ ├── memory_buffer.py
│ ├── metric/
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── model.py
│ ├── net_utils.py
│ ├── profiler/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── empty_annotations.py
│ │ ├── mstx_profile.py
│ │ ├── nvtx_profile.py
│ │ ├── performance.py
│ │ └── profile.py
│ ├── py_functional.py
│ ├── ray_utils.py
│ ├── rendezvous/
│ │ ├── __init__.py
│ │ └── ray_backend.py
│ ├── reward_score/
│ │ ├── __init__.py
│ │ ├── geo3k.py
│ │ ├── gsm8k.py
│ │ ├── math.py
│ │ ├── math_batch.py
│ │ ├── math_dapo.py
│ │ ├── math_verify.py
│ │ ├── prime_code/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── testing_util.py
│ │ │ └── utils.py
│ │ ├── prime_math/
│ │ │ ├── __init__.py
│ │ │ ├── grader.py
│ │ │ └── math_normalize.py
│ │ ├── sandbox_fusion/
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ └── search_r1_like_qa_em.py
│ ├── rollout_trace.py
│ ├── seqlen_balancing.py
│ ├── tokenizer.py
│ ├── torch_dtypes.py
│ ├── torch_functional.py
│ ├── tracking.py
│ ├── ulysses.py
│ └── vllm_utils.py
├── version/
│ └── version
└── workers/
├── __init__.py
├── actor/
│ ├── __init__.py
│ ├── base.py
│ ├── dp_actor.py
│ └── megatron_actor.py
├── critic/
│ ├── __init__.py
│ ├── base.py
│ ├── dp_critic.py
│ └── megatron_critic.py
├── engine/
│ ├── __init__.py
│ ├── base.py
│ ├── fsdp/
│ │ ├── __init__.py
│ │ ├── engine_impl.py
│ │ └── utils.py
│ └── megatron/
│ ├── __init__.py
│ └── engine_impl.py
├── fsdp_workers.py
├── megatron_workers.py
├── reward_manager/
│ ├── __init__.py
│ ├── batch.py
│ ├── dapo.py
│ ├── naive.py
│ ├── prime.py
│ └── registry.py
├── reward_model/
│ ├── __init__.py
│ ├── base.py
│ └── megatron/
│ ├── __init__.py
│ └── reward_model.py
├── roles/
│ ├── __init__.py
│ ├── actor.py
│ └── critic.py
├── rollout/
│ ├── __init__.py
│ ├── async_server.py
│ ├── base.py
│ ├── chat_scheduler.py
│ ├── hf_rollout.py
│ ├── naive/
│ │ ├── __init__.py
│ │ └── naive_rollout.py
│ ├── schemas.py
│ ├── sglang_rollout/
│ │ ├── __init__.py
│ │ ├── async_sglang_server.py
│ │ ├── sglang_rollout.py
│ │ └── utils.py
│ ├── tokenizer.py
│ └── vllm_rollout/
│ ├── __init__.py
│ ├── vllm_async_server.py
│ └── vllm_rollout_spmd.py
└── sharding_manager/
├── __init__.py
├── base.py
├── fsdp_sglang.py
├── fsdp_ulysses.py
├── fsdp_vllm.py
├── megatron_sglang.py
└── megatron_vllm.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# IDE
.idea/
.vscode/
.claude/
.gemini/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
.venv/
venv/
ENV/
env/
# Logs
*.log
logs/
tmp_ray/
# Jupyter
.ipynb_checkpoints/
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
# ML/DL
wandb/
mlruns/
*.ckpt
*.pt
*.pth
*.bin
*.safetensors
output/
checkpoints/
ckpt/
# Data
# *.parquet
# *.csv
# *.json
# *.jsonl
# Ray
ray_results/
================================================
FILE: README.md
================================================
<div align="center">
<h1>OpenOneRec</h1>
<p align="center">
<strong>An Open Foundation Model and Benchmark to Accelerate Generative Recommendation</strong>
</p>
<p align="center">
<a href="https://huggingface.co/OpenOneRec">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OneRec-ffc107?color=ffc107&logoColor=white" />
</a>
<a href="https://github.com/Kuaishou-OneRec/OpenOneRec">
<img alt="GitHub Code" src="https://img.shields.io/badge/GitHub-OpenOneRec-black?logo=github" />
</a>
<a href="https://arxiv.org/abs/2512.24762">
<img alt="Paper" src="https://img.shields.io/badge/Paper-ArXiv-b31b1b?logo=arxiv" />
</a>
<a href="#license">
<img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green" />
</a>
</p>
</div>
<br>
## 📖 Introduction
**OpenOneRec** is an open-source framework designed to bridge the gap between traditional recommendation systems and Large Language Models (LLMs). While Generative Recommendation has shown promise, existing models often struggle with isolated data silos and a lack of reasoning capabilities.
To address this, we introduce a unified framework that comprises:
* **RecIF-Bench**: The first holistic Recommendation Instruction-Following Benchmark, containing **100M interactions** from 200k users across heterogeneous domains (Short Video, Ads, Product).
* **OneRec-Foundation Models**: A family of models (1.7B & 8B) built on the Qwen3 backbone. The series includes **Standard** versions trained on our open-source dataset and **Pro** versions enhanced with a hundred-billion-token industrial corpus from Kuaishou.
* **Full-Stack Pipeline**: We open-source our comprehensive training pipeline, including data processing, co-pretraining, and post-training, to ensure full reproducibility and facilitate scaling law research in recommendation.
## 🔥 News
* **[2026.1.1]** 📑 **The technical report** has been released.
* **[2026.1.1]** 🎉 **OneRec-Foundation** models (1.7B, 8B) are now available on Hugging Face!
* **[2026.1.1]** 🚀 **RecIF-Bench** dataset and evaluation scripts are open-sourced.
## 📊 RecIF-Bench
We propose **RecIF-Bench** to rigorously assess the synergy between instruction following and domain-specific recommendation. It organizes 8 distinct tasks into a four-layer capability hierarchy:
* **Layer 0: Semantic Alignment** (Item Understanding)
* **Layer 1: Fundamental Prediction** (Short Video Rec, Ad Rec, Product Rec, Label Prediction)
* **Layer 2: Instruction Following** (Interactive Rec, Label-Conditional Rec)
* **Layer 3: Reasoning** (Recommendation Explanation)
The benchmark aggregates data from three domains: **Short Video** (Content), **Ads** (Commercial), and **Product** (E-commerce).
## 🤖 Model Zoo
The OpenOneRec-Foundation series is built upon the Qwen architecture, enhanced with **Itemic Tokens** for modality alignment and trained via a multi-stage protocol.
| Model | Backbone | Parameters | Description | Link |
| :--- | :--- | :--- | :--- | :--- |
| **OneRec-1.7B** | Qwen3-1.7B | 1.7B | Standard version trained on open-source data (~33B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-1.7B) |
| **OneRec-8B** | Qwen3-8B | 8B | Standard version trained on open-source data (~33B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-8B) |
| **OneRec-1.7B-Pro** | Qwen3-1.7B | 1.7B | Scaled-up version with expanded datasets (~130B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-1.7B-pro) |
| **OneRec-8B-Pro** | Qwen3-8B | 8B | Scaled-up version with expanded datasets (~130B tokens) | [HuggingFace](https://huggingface.co/OpenOneRec/OneRec-8B-pro) |
## 🏗️ Method & Architecture
OpenOneRec reframes recommendation as a general-purpose sequence modeling paradigm.
### 1. Items as Tokens
To bridge the modality gap, we treat items as a distinct modality using **Itemic Tokens** derived from hierarchical vector quantization. This allows the LLM to process interaction history as a cohesive context sequence.
### 2. Training Pipeline
Our framework utilizes the following recipe:
* **Pre-Training**: Integrates collaborative signals via Itemic-Text Alignment and Full-Parameter Co-Pretraining.
* **Post-Training**:
* *Stage 1*: Multi-task Supervised Fine-tuning for basic instruction following.
* *Stage 2*: On-policy Distillation to restore general reasoning performance.
* *Stage 3*: Reinforcement Learning to enhance recommendation capabilities.
<div align="center">
<img src="assets/main_framework.png" width="80%" alt="OpenOneRec Overall Framework" />
<br>
<em>Figure: The Overall Framework of OpenOneRec.</em>
</div>
## 📈 Performance
### Results on RecIF-Bench
OpenOneRec-Foundation achieves **State-of-the-Art (SOTA)** results across RecIF-Bench tasks, significantly outperforming baselines like LC-Rec and TIGER.
| Task | Metric | SASRec | TIGER | LC-Rec | OneRec-1.7B | OneRec-8B | OneRec-1.7B-Pro | **OneRec-8B-Pro** |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| **Short Video Rec** | Recall@32 | 0.0119 | 0.0132 | 0.0180 | 0.0272 | 0.0355 | 0.0274 | **0.0369** |
| **Ad Rec** | Recall@32 | 0.0293 | 0.0581 | 0.0723 | 0.0707 | 0.0877 | 0.0735 | **0.0964** |
| **Product Rec** | Recall@32 | 0.0175 | 0.0283 | 0.0416 | 0.0360 | 0.0470 | 0.0405 | **0.0538** |
| **Label-Cond. Rec** | Recall@32 | 0.0140 | 0.0123 | 0.0170 | 0.0184 | 0.0228 | 0.0182 | **0.0235** |
| **Label Pred.** | AUC | 0.6244 | 0.6675 | 0.6139 | 0.6184 | 0.6615 | 0.6071 | **0.6912** |
| **Interactive Rec** | Recall@32 | -- | -- | 0.2394 | 0.1941 | 0.3032 | 0.2024 | **0.3458** |
| **Item Und.** | LLM Score | -- | -- | 0.2517 | 0.3175 | 0.3202 | 0.3133 | **0.3209** |
| **Rec. Explanation** | LLM Score | -- | -- | 3.9350 | 3.3540 | 3.6774 | 3.5060 | **4.0381** |
<div align="center">
<img src="assets/benchmark.png" width="80%" alt="Holistic Performance Overview of OpenOneRec." />
<br>
<em>Holistic Performance Overview of OpenOneRec.</em>
</div>
### Cross-Domain Transferability
On the **Amazon Benchmark** (10 datasets), OpenOneRec demonstrates exceptional zero-shot/few-shot transfer capabilities, achieving an average **26.8% improvement** in Recall@10 over the second-best method.
| Domain | SASRec | TIGER | LC-Rec | **Ours** |
| :--- | :--- | :--- | :--- | :--- |
| Baby | 0.0381 | 0.0318 | 0.0344 | **0.0513** |
| Beauty | 0.0639 | 0.0628 | 0.0764 | **0.0924** |
| Cell Phones | 0.0782 | 0.0786 | 0.0883 | **0.1036** |
| Grocery | 0.0789 | 0.0691 | 0.0790 | **0.1029** |
| Health | 0.0506 | 0.0534 | 0.0616 | **0.0768** |
| Home | 0.0212 | 0.0216 | 0.0293 | **0.0390** |
| Pet Supplies | 0.0607 | 0.0542 | 0.0612 | **0.0834** |
| Sports | 0.0389 | 0.0331 | 0.0418 | **0.0547** |
| Tools | 0.0437 | 0.0344 | 0.0438 | **0.0593** |
| Toys | 0.0658 | 0.0527 | 0.0549 | **0.0953** |
*Metric: Recall@10. Ours refers to OneRec-Foundation with text-augmented itemic tokens strategy. For implementation details, please refer to [GRLM](https://github.com/ZY0025/GRLM).*
## 🚀 Quick Start
*Code release and detailed usage instructions are coming soon.*
Currently, you can load our models using `transformers>=4.51.0`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "OpenOneRec/OneRec-8B"
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
# prepare the model input
# case - prompt with itemic tokens
prompt = "这是一个视频:<|sid_begin|><s_a_340><s_b_6566><s_c_5603><|sid_end|>,帮我总结一下这个视频讲述了什么内容"
messages = [
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# conduct text completion
# Note: In our experience, default decoding settings may be unstable for small models.
# For 1.7B, we suggest: top_p=0.95, top_k=20, temperature=0.75 (within the 0.6 to 0.8 range)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
# parsing thinking content
try:
# rindex finding 151668 (</think>)
index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
index = 0
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
print("thinking content:", thinking_content)
print("content:", content)
```
## 🛣️ Roadmap / Under Development
We are actively working on the following features:
- [ ] **General-domain data**: scripts to fetch and preprocess public general-domain corpora used in `data/general_text`.
- [ ] **Reproducible environments**: training pipeline Docker/Apptainer images for easier end-to-end reproduction.
- [ ] **One-click reproduction**: further code cleanup and streamlined training recipes for an end-to-end “run from scratch” experience.
- [ ] **Docs & tutorials**: improved documentation, tutorials, and best-practice guides.
- [ ] **Unified VeRL integration**: consolidate RL and distillation codepaths into a single, consistent VeRL-based implementation.
- [ ] **More model sizes**: support additional pretraining scales and configurations beyond current checkpoints.
Contributions are welcome! Please refer to the detailed documentation in each module.
## 📜 Citation
If you find our work helpful, please cite our technical report:
```bibtex
@misc{OpenOneRec,
title={OpenOneRec Technical Report},
author={Guorui Zhou and Honghui Bao and Jiaming Huang and Jiaxin Deng and Jinghao Zhang and Junda She and Kuo Cai and Lejian Ren and Lu Ren and Qiang Luo and Qianqian Wang and Qigen Hu and Rongzhou Zhang and Ruiming Tang and Shiyao Wang and Wuchao Li and Xiangyu Wu and Xinchen Luo and Xingmei Wang and Yifei Hu and Yunfan Wu and Zhanyu Liu and Zhiyang Zhang and Zixing Zhang and Bo Chen and Bin Wen and Chaoyi Ma and Chengru Song and Chenglong Chu and Defu Lian and Fan Yang and Feng Jiang and Hongtao Cheng and Huanjie Wang and Kun Gai and Pengfei Zheng and Qiang Wang and Rui Huang and Siyang Mao and Tingting Gao and Wei Yuan and Yan Wang and Yang Zhou and Yi Su and Zexuan Cheng and Zhixin Ling and Ziming Li},
year={2025},
eprint={2512.24762},
archivePrefix={arXiv},
primaryClass={cs.IR}
}
```
## 🛡️ License
The code in this repository is licensed under the Apache 2.0 License. The model weights are subject to their specific license agreements.
## 🙏 Acknowledgements
OpenOneRec is built upon and inspired by the open-source ecosystem. We would like to thank:
- **Qwen3**: for providing the base architecture and model initialization that OpenOneRec builds upon.
- **General-domain data sources**: for the public corpora referenced in [`data/general_text`](https://github.com/Kuaishou-OneRec/OpenOneRec/tree/main/data/general_text) used for mixed-domain training.
- **VeRL & PyTorch distributed training**: for the training infrastructure and scalable primitives (e.g., **FSDP**) used in post-training and large-scale runs.
We sincerely thank these projects for their outstanding work.
================================================
FILE: benchmarks/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2025] [OneRec Team]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: benchmarks/README.md
================================================
# Benchmark
## Quick Start
### Step 1: Install Dependencies
```bash
cd benchmarks
conda create -n benchmark python=3.10
conda activate benchmark
pip install uv
uv pip install torch==2.5.1 transformers==4.52.0 vllm==0.7.3
pip install -r requirements.txt
pip install -e . --no-deps --no-build-isolation
```
### Step 2: Start Ray Cluster (Optional)
```bash
# Initialize multi-node multi-GPU environment
# Skip this step if using single-node multi-GPU setup
bash scripts/init_ray_cluster.sh
```
### Step 3: Configure LLM API
Edit `api/config/llm_config.json` to fill in your Gemini configuration:
```json
{
"gemini": {
"project": "<your-project>",
"location": "<your-location>",
"credentials_path": "<path-to-credentials>",
...
}
}
```
**Note**: Only `project`, `location`, and `credentials_path` need to be configured.
Test the configuration:
```python
from api import get_client_from_config
# Create client
client = get_client_from_config("gemini")
# Generate text
response = client.generate("Tell me a joke")
print(response)
```
### Step 4: Run Evaluation
```bash
export BENCHMARK_BASE_DIR="."
export BENCHMARK_DATA_DIR="../raw_data/onerec_data/benchmark_data"
export DATA_VERSION="v1.0"
bash eval_script.sh <model_path> <result_name> <enable_thinking>
```
**Parameters**:
| Parameter | Description | Example |
|-----------|-------------|---------|
| model_path | Path to the model to evaluate | `model_output/sft/global_step10/converted` |
| result_name | Name identifier for output directory | `sft_nonthink` |
| enable_thinking | `true` or `false` | `false` |
**Examples**:
```bash
# Without thinking mode
bash eval_script.sh \
/path/to/model \
model_nonthink \
false
# With thinking mode
bash eval_script.sh \
/path/to/model \
model_think \
true
```
For debugging purposes, you can add `--sample_size 10` to each python command in `eval_script.sh` to run evaluation on a smaller subset of data.
### Step 5: View Results
After evaluation completes, results are saved in:
```
./results/v1.0/results_<result_name>/
```
Log files are located at:
```
./auto_eval_logs/v1.0/<result_name>.log
```
---
## Evaluation Tasks
| Task Name | Source | Size | Description |
|-----------|--------|------|-------------|
| ad | Kuaishou Internal | 27,677 | Predict next clicked advertisement |
| product | Kuaishou Internal | 27,910 | Predict next clicked product |
| interactive | Kuaishou Internal | 1,000 | Predict next interacted video |
| video | Kuaishou Internal | 38,781 | Next video prediction |
| label_cond | Kuaishou Internal | 34,891 | Predict next video given specified consumption behavior |
| label_pred | Kuaishou Internal | 346,190 | Predict user engagement with video content |
| item_understand | Kuaishou Internal | 500 | Video SID to Caption generation task |
| rec_reason | Kuaishou Internal | 470 | Recommendation reason inference |
================================================
FILE: benchmarks/api/README.md
================================================
# Unified LLM API Wrapper
This is a unified LLM API wrapper library that provides a clean and elegant interface for calling different large language models.
## Supported Models
- **Claude** - Anthropic Claude models
- **Gemini** - Google Vertex AI Gemini models
- **DeepSeek** - DeepSeek models via Baidu Qianfan platform
## Model Pricing Comparison
- Claude: https://claude.com/pricing
- Gemini: https://ai.google.dev/gemini-api/docs/pricing
- DeepSeek: https://api-docs.deepseek.com/quick_start/pricing
## Quick Start
### Installation
```bash
pip install openai google-cloud-aiplatform anthropic tqdm
```
### Using Configuration File
First, edit `api/config/llm_config.json` to fill in your configuration:
Then use the following code to test:
```python
from api import get_client_from_config
# Create client
client = get_client_from_config("gemini")
# Generate text
response = client.generate("Tell me a joke")
print(response)
```
================================================
FILE: benchmarks/api/__init__.py
================================================
"""
Unified LLM API Wrapper
Supports convenient calling of Gemini, DeepSeek, and Claude models
"""
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
from .base import BaseLLMClient
from .gemini import GeminiClient
from .deepseek import DeepSeekClient
from .claude import ClaudeClient
# Model mapping
MODEL_CLASSES = {
"gemini": GeminiClient,
"deepseek": DeepSeekClient,
"claude": ClaudeClient,
}
def load_config(config_path: str = None) -> Dict[str, Any]:
    """
    Read the LLM configuration from a JSON file.

    Args:
        config_path: Path to the configuration file. When omitted, falls
            back to ``api/config/llm_config.json`` next to this module.

    Returns:
        dict: Parsed configuration dictionary.

    Raises:
        FileNotFoundError: The configuration file is missing.
        json.JSONDecodeError: The file is not valid JSON.
    """
    # Resolve the default location relative to this package directory.
    if config_path is None:
        config_path = Path(__file__).parent / "config" / "llm_config.json"
    else:
        config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file does not exist: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_client(model: str, **config) -> BaseLLMClient:
    """
    Factory function: build an LLM client for the requested backend.

    Args:
        model: Model name ("gemini", "deepseek" or "claude"); case-insensitive.
        **config: Backend-specific configuration parameters.

    Returns:
        BaseLLMClient: Client instance.

    Raises:
        ValueError: Unsupported model type.

    Example:
        >>> client = get_client("gemini",
        ...                     project="your-project",
        ...                     location="us-central1")
        >>> result = client.generate("Tell me a joke")
    """
    model = model.lower()
    client_class = MODEL_CLASSES.get(model)
    if client_class is None:
        raise ValueError(
            f"Unsupported model: {model}. "
            f"Supported models: {', '.join(MODEL_CLASSES.keys())}"
        )
    return client_class(**config)
def get_client_from_config(
    model: str,
    config_path: Optional[str] = None
) -> BaseLLMClient:
    """
    Build an LLM client using settings from a configuration file.

    Args:
        model: Model name ("gemini", "deepseek" or "claude"); case-insensitive.
        config_path: Configuration file path, defaults to
            api/config/llm_config.json.

    Returns:
        BaseLLMClient: Client instance.

    Raises:
        ValueError: The configuration file has no section for this model.

    Example:
        >>> client = get_client_from_config("gemini")
        >>> result = client.generate("Tell me a joke")
    """
    full_config = load_config(config_path)
    model = model.lower()
    if model not in full_config:
        raise ValueError(
            f"Model '{model}' configuration not found in configuration file. "
            f"Available models: {', '.join(full_config.keys())}"
        )
    # Forward the model's own section as keyword configuration.
    return get_client(model, **full_config[model])
def batch_generate(
    prompts: List[str],
    model: str,
    max_workers: int = 5,
    show_progress: bool = True,
    config_path: Optional[str] = None,
    **config
) -> List[Dict[str, Any]]:
    """
    Generate text for many prompts concurrently.

    Args:
        prompts: List of prompts.
        model: Model name ("gemini", "deepseek" or "claude").
        max_workers: Maximum number of concurrent threads, default 5.
        show_progress: Whether to show a progress bar, default True.
        config_path: Configuration file path; when provided it takes
            precedence over keyword configuration.
        **config: Model configuration parameters (used when no
            configuration file is supplied).

    Returns:
        List[Dict]: One result per prompt, each containing:
            - prompt: Original prompt
            - result: Generated text (on success)
            - error: Error message (on failure)
            - success: Whether the call succeeded

    Example:
        >>> # Using configuration file
        >>> results = batch_generate(
        ...     prompts=["Question 1", "Question 2", "Question 3"],
        ...     model="gemini",
        ...     max_workers=3
        ... )
        >>> # Direct configuration
        >>> results = batch_generate(
        ...     prompts=["Question 1", "Question 2"],
        ...     model="deepseek",
        ...     api_key="your-key",
        ...     appid="your-appid"
        ... )
    """
    # A configuration file, when given, wins over inline keyword config.
    client = (
        get_client_from_config(model, config_path)
        if config_path
        else get_client(model, **config)
    )
    return client.batch_generate(
        prompts=prompts,
        max_workers=max_workers,
        show_progress=show_progress
    )
# Export all public interfaces
# Names not listed here are considered internal to the package.
__all__ = [
    # Classes
    "BaseLLMClient",
    "GeminiClient",
    "DeepSeekClient",
    "ClaudeClient",
    # Functions
    "get_client",
    "get_client_from_config",
    "batch_generate",
    "load_config",
]
================================================
FILE: benchmarks/api/base.py
================================================
"""
Base LLM Client Definition
Provides unified interface specification with retry mechanism and batch processing
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
class BaseLLMClient(ABC):
"""
Base class for LLM clients, defining unified interface
All concrete LLM clients (Gemini, DeepSeek, etc.) should inherit from this class
Provides unified retry mechanism and batch processing capabilities
"""
def __init__(self, **config):
"""
Initialize client
Args:
**config: Model-specific configuration parameters
"""
self.config = config
self.max_retries = config.get("max_retries", 3)
self.retry_delay = config.get("retry_delay", 2)
self._setup()
    @abstractmethod
    def _setup(self):
        """Set up the client; invoked once at the end of ``__init__``.

        Subclasses implement backend-specific initialization here,
        presumably reading the raw keyword configuration stored on
        ``self.config`` — confirm against concrete clients.
        """
        pass
    @abstractmethod
    def _call_api(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Call API to generate text (subclasses implement specific API call
        logic). This performs a single attempt; retries and backoff are
        handled by ``_generate_with_retry``, which wraps this method.

        Args:
            prompt: Input prompt
            temperature: Temperature parameter
            max_tokens: Maximum number of tokens to generate
            **kwargs: Other model-specific parameters

        Returns:
            Generated text content

        Raises:
            Exception: Raised when API call fails
        """
        pass
def _is_retryable_error(self, error_msg: str) -> bool:
"""
Determine if error is retryable
Args:
error_msg: Error message
Returns:
bool: Whether the error is retryable
"""
retryable_keywords = [
'503', '429', '500', 'timeout', 'timed out', 'deadline',
'unavailable', 'failed to connect', 'connection',
'rate limit', 'overload'
]
return any(keyword in error_msg.lower() for keyword in retryable_keywords)
def _generate_with_retry(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
**kwargs
) -> str:
"""
Generation method with retry mechanism (template method)
Args:
prompt: Input prompt
temperature: Temperature parameter
max_tokens: Maximum number of tokens to generate
**kwargs: Other parameters
Returns:
str: Generated text content
Raises:
Exception: Raised when API call fails
"""
if not prompt or not prompt.strip():
raise ValueError("prompt cannot be empty")
last_error = None
for attempt in range(self.max_retries):
try:
if attempt > 0:
delay = self.retry_delay * (2 ** (attempt - 1))
jitter = random.uniform(0, delay * 0.3)
time.sleep(delay + jitter)
return self._call_api(prompt, temperature, max_tokens, **kwargs)
except Exception as e:
last_error = e
error_msg = str(e)
is_retryable = self._is_retryable_error(error_msg)
if attempt == self.max_retries - 1 or not is_retryable:
raise Exception(f"{self.__class__.__name__} API call failed: {error_msg}")
print(f"{self.__class__.__name__} API call failed "
f"(attempt {attempt + 1}/{self.max_retries}), "
f"will retry in {self.retry_delay} seconds: {error_msg[:100]}")
raise Exception(f"Maximum retry attempts reached ({self.max_retries}): {last_error}")
def generate(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
**kwargs
) -> str:
"""
Generate text content (public interface)
Args:
prompt: Input prompt
temperature: Temperature parameter (controls randomness)
max_tokens: Maximum number of tokens to generate
**kwargs: Other model-specific parameters
Returns:
str: Generated text content
Raises:
ValueError: Parameter error
Exception: API call failed
"""
return self._generate_with_retry(prompt, temperature, max_tokens, **kwargs)
def batch_generate(
self,
prompts: List[str],
max_workers: int = 5,
show_progress: bool = True,
**kwargs
) -> List[Dict[str, Any]]:
"""
Batch generate text (with concurrent support)
Args:
prompts: List of prompts
max_workers: Maximum number of concurrent threads, default 5
show_progress: Whether to show progress bar, default True
**kwargs: Other parameters to pass to generate
Returns:
List[Dict]: List of results, each element contains:
- prompt: Original prompt
- result: Generated text (on success)
- error: Error message (on failure)
- success: Whether successful
"""
try:
from tqdm import tqdm
has_tqdm = True
except ImportError:
has_tqdm = False
if show_progress:
print("Warning: tqdm not installed, cannot show progress bar")
def process_prompt(prompt: str, index: int) -> Dict[str, Any]:
try:
result = self.generate(prompt, **kwargs)
return {
"index": index,
"prompt": prompt,
"result": result,
"success": True
}
except Exception as e:
return {
"index": index,
"prompt": prompt,
"error": str(e),
"success": False
}
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index = {
executor.submit(process_prompt, prompt, i): i
for i, prompt in enumerate(prompts)
}
if show_progress and has_tqdm:
progress = tqdm(
as_completed(future_to_index),
total=len(prompts),
desc=f"Generating ({self.__class__.__name__})"
)
else:
progress = as_completed(future_to_index)
temp_results = []
for future in progress:
try:
result = future.result()
temp_results.append(result)
except Exception as e:
index = future_to_index[future]
temp_results.append({
"index": index,
"prompt": prompts[index],
"error": f"Task execution failed: {str(e)}",
"success": False
})
results = sorted(temp_results, key=lambda x: x["index"])
for r in results:
r.pop("index", None)
return results
def __repr__(self) -> str:
return f"{self.__class__.__name__}(config={self.config})"
================================================
FILE: benchmarks/api/claude.py
================================================
"""
Claude API Client Implementation
Based on Anthropic official SDK
"""
from typing import Optional
from anthropic import Anthropic
from .base import BaseLLMClient
class ClaudeClient(BaseLLMClient):
    """
    Claude API Client

    Example:
        >>> client = ClaudeClient(
        ...     api_key="your-api-key",
        ...     model_name="claude-sonnet-4-20250514"
        ... )
        >>> response = client.generate("Tell me a joke")
    """

    def _setup(self):
        """Read configuration and construct the underlying Anthropic SDK client."""
        cfg = self.config
        self.api_key = cfg.get("api_key")
        self.model_name = cfg.get("model_name", "claude-sonnet-4-20250514")
        self.base_url = cfg.get("base_url")
        self.default_max_tokens = cfg.get("max_new_tokens", 1024)
        self.default_temperature = cfg.get("temperature", 1.0)
        if not self.api_key:
            raise ValueError("api_key is a required parameter")
        client_kwargs = {"api_key": self.api_key}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = Anthropic(**client_kwargs)

    def _call_api(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Call Claude API to generate text.

        Args:
            prompt: Input prompt
            temperature: Temperature parameter (0.0-1.0); falls back to the
                configured default when None
            max_tokens: Maximum number of tokens to generate; falls back to
                the configured default when None
            **kwargs: Optional Claude-specific parameters. Recognized keys:
                system, top_p, top_k, stop_sequences. Other keys are ignored.

        Returns:
            str: Concatenation of all text blocks in the response

        Raises:
            Exception: When the API call fails or returns no text
        """
        temp = self.default_temperature if temperature is None else temperature
        limit = self.default_max_tokens if max_tokens is None else max_tokens
        system = kwargs.pop("system", None)
        request = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": limit,
        }
        if temp is not None:
            request["temperature"] = temp
        if system:
            request["system"] = system
        # Forward only the whitelisted sampling options.
        for opt in ("top_p", "top_k", "stop_sequences"):
            if opt in kwargs:
                request[opt] = kwargs.pop(opt)
        response = self.client.messages.create(**request)
        if not response or not response.content:
            raise Exception("API returned invalid response")
        pieces = [block.text for block in response.content if hasattr(block, 'text')]
        if not pieces:
            raise Exception("API returned empty response")
        return "".join(pieces)
================================================
FILE: benchmarks/api/config/llm_config.json
================================================
{
"gemini": {
"project": "",
"location": "",
"model_name": "gemini-2.5-flash-lite",
"credentials_path": "",
"max_new_tokens": 10000,
"temperature": 0.01,
"max_retries": 3,
"retry_delay": 2
},
"deepseek": {
"api_key": "",
"base_url": "",
"model_name": "deepseek-r1",
"appid": "",
"max_new_tokens": 10000,
"temperature": 0.01,
"max_retries": 3,
"retry_delay": 2
},
"claude": {
"api_key": "",
"base_url": "",
"model_name": "",
"max_new_tokens": 10000,
"temperature": 0.01,
"max_retries": 3,
"retry_delay": 2
}
}
================================================
FILE: benchmarks/api/deepseek.py
================================================
"""
DeepSeek API Client Implementation
Call DeepSeek model through Baidu Qianfan platform
"""
from typing import Optional
from openai import OpenAI
from .base import BaseLLMClient
class DeepSeekClient(BaseLLMClient):
    """
    DeepSeek API Client (through Baidu Qianfan platform)

    Example:
        >>> client = DeepSeekClient(
        ...     api_key="your-api-key",
        ...     base_url="https://qianfan.baidubce.com/v2",
        ...     model_name="deepseek-r1",
        ...     appid="your-appid"
        ... )
        >>> response = client.generate("Tell me a joke")
    """

    def _setup(self):
        """Validate configuration and build the OpenAI-compatible Qianfan client."""
        cfg = self.config
        self.api_key = cfg.get("api_key")
        self.base_url = cfg.get("base_url", "https://qianfan.baidubce.com/v2")
        self.model_name = cfg.get("model_name", "deepseek-r1")
        self.appid = cfg.get("appid")
        self.default_max_tokens = cfg.get("max_new_tokens", 300)
        self.default_temperature = cfg.get("temperature", 0.7)
        if not self.api_key:
            raise ValueError("api_key is a required parameter")
        if not self.appid:
            raise ValueError("appid is a required parameter")
        # Qianfan routes requests per application via the "appid" header.
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            default_headers={"appid": self.appid}
        )

    def _call_api(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Call DeepSeek API to generate text.

        Args:
            prompt: Input prompt
            temperature: Temperature parameter (0.0-2.0); falls back to the
                configured default when None
            max_tokens: Maximum number of tokens to generate; falls back to
                the configured default when None
            **kwargs: Other DeepSeek-specific parameters, merged verbatim
                into the request (may override the defaults above)

        Returns:
            str: Generated text content

        Raises:
            Exception: When the API call fails or returns no content
        """
        params = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.default_temperature if temperature is None else temperature,
            "max_tokens": self.default_max_tokens if max_tokens is None else max_tokens,
            "stream": False
        }
        params.update(kwargs)
        response = self.client.chat.completions.create(**params)
        if not response or not response.choices:
            raise Exception("API returned invalid response")
        content = response.choices[0].message.content
        if not content:
            raise Exception("API returned empty response")
        return content
================================================
FILE: benchmarks/api/example.py
================================================
"""
LLM API Usage Examples
Demonstrates various calling methods and use cases
"""
# ============================================================================
# Example 1: Using Configuration File (Simplest)
# ============================================================================
def example1_use_config():
    """Load and use from configuration file"""
    from api import get_client_from_config

    banner = "=" * 60
    print(banner)
    print("Example 1: Using Configuration File")
    print(banner)
    # Build the client from api/config/llm_config.json
    llm = get_client_from_config("gemini")
    # Single-prompt generation
    answer = llm.generate("Explain what AI is in one sentence")
    print(f"Answer: {answer}\n")
# ============================================================================
# Example 2: Direct Parameters
# ============================================================================
def example2_direct_params():
    """Pass configuration parameters directly"""
    from api import get_client

    banner = "=" * 60
    print(banner)
    print("Example 2: Direct Parameters")
    print(banner)
    # Gemini: configured inline instead of via the config file
    gemini_client = get_client(
        "gemini",
        project="your-project",
        location="us-central1",
        model_name="gemini-2.5-pro",
        credentials_path="path/to/credentials.json",
    )
    # DeepSeek: same idea, different provider
    deepseek_client = get_client(
        "deepseek",
        api_key="your-api-key",
        appid="your-appid",
        base_url="https://qianfan.baidubce.com/v2",
    )
    # Usage
    response = gemini_client.generate("Hello")
    print(f"Gemini: {response}\n")
# ============================================================================
# Example 3: Batch Generation (Concurrent)
# ============================================================================
def example3_batch_generate():
    """Batch text generation with concurrent support"""
    from api import get_client_from_config

    banner = "=" * 60
    print(banner)
    print("Example 3: Batch Generation (Concurrent)")
    print(banner)
    questions = [
        "What is machine learning?",
        "Explain deep learning",
        "Principles of neural networks",
        "What is natural language processing?",
        "Applications of computer vision",
    ]
    # Use client instance's batch_generate method (recommended)
    llm = get_client_from_config("gemini")
    outputs = llm.batch_generate(
        prompts=questions,
        max_workers=3,       # 3 concurrent threads
        show_progress=True,  # Show progress bar
    )
    # Walk the ordered results; each entry is either a result or an error
    for num, entry in enumerate(outputs, 1):
        print(f"\nQuestion {num}: {entry['prompt']}")
        if entry['success']:
            print(f"Answer: {entry['result'][:100]}...")
        else:
            print(f"Error: {entry['error']}")
# ============================================================================
# Example 4: Custom Generation Parameters
# ============================================================================
def example4_custom_params():
    """Custom generation parameters"""
    from api import get_client_from_config

    banner = "=" * 60
    print(banner)
    print("Example 4: Custom Generation Parameters")
    print(banner)
    llm = get_client_from_config("deepseek")
    # Creative generation (high temperature)
    creative = llm.generate(
        "Write a poem about spring",
        temperature=0.9,
        max_tokens=200,
    )
    print(f"Creative output:\n{creative}\n")
    # Precise generation (low temperature)
    precise = llm.generate(
        "What is 1+1?",
        temperature=0.1,
        max_tokens=50,
    )
    print(f"Precise output:\n{precise}\n")
# ============================================================================
# Example 5: Error Handling
# ============================================================================
def example5_error_handling():
    """Demonstrate error handling"""
    from api import get_client_from_config

    banner = "=" * 60
    print(banner)
    print("Example 5: Error Handling")
    print(banner)
    try:
        llm = get_client_from_config("gemini")
        # Normal call
        reply = llm.generate("Hello")
        print(f"Success: {reply}")
        # Empty prompt (will raise ValueError)
        reply = llm.generate("")
    except ValueError as err:
        print(f"Parameter error: {err}")
    except Exception as err:
        print(f"API call failed: {err}")
# ============================================================================
# Example 6: Switch Models
# ============================================================================
def example6_switch_models():
    """Switch between different models"""
    from api import get_client_from_config

    banner = "=" * 60
    print(banner)
    print("Example 6: Switch Models")
    print(banner)
    question = "What is quantum computing?"
    # Ask the same question to each configured backend
    for name in ("gemini", "deepseek"):
        try:
            answer = get_client_from_config(name).generate(question)
            print(f"\n{name.upper()}'s answer:")
            print(answer[:150] + "...")
        except Exception as err:
            print(f"\n{name} call failed: {err}")
# ============================================================================
# Example 7: Real Application - User Profile Generation
# ============================================================================
def example7_user_portrait():
    """Real application: Generate user profile based on user behavior"""
    from api import get_client_from_config
    print("=" * 60)
    print("Example 7: User Profile Generation")
    print("=" * 60)
    # User behavior data (stands in for a real watch-history feed)
    user_behavior = """
    User's recently watched videos:
    1. Machine Learning Tutorial
    2. Python Programming Tips
    3. Deep Learning Practical Projects
    4. Data Analysis Case Studies
    5. Latest AI Trends
    """
    # Prompt template: behavior data is interpolated into the instruction
    prompt = f"""Based on the following user behavior data, generate a concise user profile:
    {user_behavior}
    Requirements:
    1. Summarize user's areas of interest
    2. Infer user's skill level
    3. Provide 3-5 precise tags
    """
    client = get_client_from_config("gemini")
    # Moderate temperature: profile should be stable but not robotic
    portrait = client.generate(prompt, temperature=0.5)
    print("User Profile:")
    print(portrait)
# ============================================================================
# Example 8: Direct Import of Classes
# ============================================================================
def example8_direct_import():
    """Import client classes directly"""
    from api import GeminiClient, DeepSeekClient

    banner = "=" * 60
    print(banner)
    print("Example 8: Direct Import of Client Classes")
    print(banner)
    # Direct instantiation, bypassing the factory helpers
    gemini = GeminiClient(
        project="your-project",
        location="us-central1",
    )
    deepseek = DeepSeekClient(
        api_key="your-key",
        appid="your-appid",
    )
    print("Clients created successfully")
    print(f"Gemini client: {gemini}")
    print(f"DeepSeek client: {deepseek}")
# ============================================================================
# Main Function
# ============================================================================
def main():
    """Run all examples"""
    examples = [
        ("Using Configuration File", example1_use_config),
        ("Direct Parameters", example2_direct_params),
        ("Batch Generation", example3_batch_generate),
        ("Custom Parameters", example4_custom_params),
        ("Error Handling", example5_error_handling),
        ("Switch Models", example6_switch_models),
        ("User Profile Generation", example7_user_portrait),
        ("Direct Import Classes", example8_direct_import),
    ]
    banner = "=" * 60
    print("\n" + banner)
    print("LLM API Usage Examples")
    print(banner)
    print("\nAvailable examples:")
    for num, (title, _) in enumerate(examples, 1):
        print(f"{num}. {title}")
    print("\nNote: Please ensure api/config/llm_config.json is configured before running")
    print("\n" + banner + "\n")
    # Uncomment the lines below to run specific examples
    # example1_use_config()
    # example2_direct_params()
    # example3_batch_generate()
    # example4_custom_params()
    # example5_error_handling()
    # example6_switch_models()
    # example7_user_portrait()
    # example8_direct_import()


if __name__ == "__main__":
    main()
================================================
FILE: benchmarks/api/gemini.py
================================================
"""
Gemini API Client Implementation
Based on Google Vertex AI's Gemini model
"""
import os
from typing import Optional
from vertexai.generative_models import GenerativeModel
import vertexai
from .base import BaseLLMClient
class GeminiClient(BaseLLMClient):
    """
    Gemini API Client

    Example:
        >>> client = GeminiClient(
        ...     project="your-project",
        ...     location="us-central1",
        ...     model_name="gemini-2.5-pro",
        ...     credentials_path="path/to/credentials.json"
        ... )
        >>> response = client.generate("Tell me a joke")
    """

    def _setup(self):
        """Initialize Gemini client (Vertex AI SDK)."""
        self.project = self.config.get("project")
        self.location = self.config.get("location")
        self.model_name = self.config.get("model_name", "gemini-2.5-pro")
        credentials_path = self.config.get("credentials_path")
        # No hard defaults: None means "let the API decide".
        self.default_max_tokens = self.config.get("max_new_tokens")
        self.default_temperature = self.config.get("temperature")
        if credentials_path:
            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
        if not self.project or not self.location:
            raise ValueError("project and location are required parameters")
        vertexai.init(project=self.project, location=self.location)
        self.model = GenerativeModel(self.model_name)

    def _call_api(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Call Gemini API to generate text

        Args:
            prompt: Input prompt
            temperature: Temperature parameter (0.0-1.0)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Other Gemini-specific parameters (currently unused)

        Returns:
            str: Generated text content

        Raises:
            Exception: Raised when API call fails or returns no usable text
        """
        if temperature is None:
            temperature = self.default_temperature
        if max_tokens is None:
            max_tokens = self.default_max_tokens
        generation_config = {}
        if temperature is not None:
            generation_config["temperature"] = temperature
        if max_tokens is not None:
            generation_config["max_output_tokens"] = max_tokens
        if generation_config:
            response = self.model.generate_content(
                prompt,
                generation_config=generation_config
            )
        else:
            response = self.model.generate_content(prompt)
        if not response:
            raise Exception("API returned empty response")
        # Bug fix: the Vertex AI SDK's `response.text` property raises
        # ValueError when the response contains no valid candidates/parts
        # (e.g. blocked by safety filters), so the attribute access must be
        # guarded instead of being used directly in a truthiness check.
        try:
            text = response.text
        except ValueError as e:
            raise Exception(f"API returned no usable text: {e}") from e
        if not text:
            raise Exception("API returned empty response")
        return text
================================================
FILE: benchmarks/benchmark/__init__.py
================================================
# Re-export the package's primary entry points so callers can simply do
# `from benchmark import Benchmark` instead of reaching into submodules.
from benchmark.benchmark import Benchmark
from benchmark.base_generator import Generator
from benchmark.generation_runner import GenerationRunner

# Package version string (manually maintained).
__version__ = "0.1.0"

# Public API surface of the package.
__all__ = [
    "Benchmark",
    "Generator",
    "GenerationRunner",
]
================================================
FILE: benchmarks/benchmark/base_generator.py
================================================
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional
from collections import defaultdict
from benchmark.console import *
# Global configuration: tasks that should disable optimizations (long prompts may cause issues)
# Used by vLLM-based generators to control chunked_prefill and prefix_caching
DISABLE_OPTIMIZATIONS_FOR_TASKS = ["rec_reason", "interactive"]
class Generator(ABC):
"""
Abstract base class for generation models
All generation models should inherit from this class.
Subclasses must implement _generate_standard() to support the generate() method.
"""
def __init__(
    self,
    **kwargs
):
    """
    Base initializer; intentionally a no-op.

    Subclasses perform their own initialization (loading the model and
    storing generation defaults such as num_return_sequences and
    max_new_tokens) and must set ``self.model_name`` for __str__ to work.

    Args:
        **kwargs: Generation parameters; ignored here, consumed by subclasses.
    """
    pass
def __str__(self) -> str:
"""
Return model name (for directory naming, remove path separators)
This method is shared across all generator implementations.
Subclasses must set self.model_name for this method to work.
Returns:
str: Model name
"""
return os.path.basename(self.model_name.rstrip('/'))
def generate(
    self,
    prompts: Dict[str, str],
    **kwargs
) -> tuple:
    """
    Batch text generation dispatcher.

    Routes to one of several generation strategies based on kwargs:
    - target_tokens present  -> classification (logprob extraction),
      optionally preceded by a thinking stage (enable_thinking)
    - max_new_thinking_tokens -> two-stage thinking generation, or
      single-stage with prompt_token appended / plain sampling
    - otherwise              -> standard single-stage sampling

    Shared across all generator implementations; subclasses must implement
    _generate_standard() (and extract_token_logprobs() for classification).

    Args:
        prompts: {sample_id: prompt_text}
        **kwargs: Optional generation parameters (override init-time defaults)

    Returns:
        Tuple of two dicts:
        - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}
        - Second dict: {sample_id: [cum_logprob_1, ...]} (only for beam search)
    """
    prompt_token = kwargs.get("prompt_token", None)
    enable_thinking = kwargs.get("enable_thinking", False)
    max_new_thinking_tokens = kwargs.get("max_new_thinking_tokens", None)
    target_tokens = kwargs.get("target_tokens", None)

    def run_standard(batch):
        # Shared tail for all single-stage paths: generate, stash MFU stats.
        outputs, logprobs, mfu_stats = self._generate_standard(batch, **kwargs)
        self.mfu_stats = mfu_stats
        return outputs, logprobs

    # Scenarios D & E: classification task (identified by target_tokens)
    if target_tokens is not None:
        if enable_thinking:
            # E: Classification with thinking
            console.print(
                f"Two-stage classification with thinking enabled: thinking (max_new_thinking_tokens={max_new_thinking_tokens}) + logprobs extraction for {target_tokens}",
                style=warning_style,
            )
            return self._generate_two_stage_classification_with_thinking(prompts, **kwargs)
        # D: Classification without thinking
        console.print(
            f"Classification task: extracting logprobs for tokens {target_tokens}",
            style=warning_style,
        )
        # Strip target_tokens from kwargs to avoid passing it twice
        cls_kwargs = dict(kwargs)
        cls_kwargs.pop("target_tokens", None)
        outputs, _, mfu_stats = self.extract_token_logprobs(prompts, target_tokens, **cls_kwargs)
        self.mfu_stats = mfu_stats
        return outputs, {}

    # Scenarios A & B: thinking budget configured
    if max_new_thinking_tokens:
        if enable_thinking:
            # A & B with thinking: two-stage generation
            console.print(
                f"Two-stage generation enabled: thinking (max_new_thinking_tokens={max_new_thinking_tokens}) + prompt_token ({prompt_token})",
                style=warning_style,
            )
            return self._generate_two_stage_with_thinking(prompts, **kwargs)
        if prompt_token:
            # A without thinking: single-stage with prompt_token (beam search)
            console.print(
                f"Single-stage generation with prompt_token ({prompt_token})",
                style=warning_style,
            )
            with_token = {
                sample_id: text + prompt_token
                for sample_id, text in prompts.items()
            }
            return run_standard(with_token)
        # B without thinking: single-stage sampling (thinking budget ignored)
        console.print(
            f"Warning: max_new_thinking_tokens={max_new_thinking_tokens} is set but "
            f"enable_thinking=False and prompt_token=None. The max_new_thinking_tokens parameter will be ignored.",
            style=warning_style,
        )
        return run_standard(prompts)

    # Scenario C: standard single-stage sampling
    return run_standard(prompts)
def get_hardware_info(self) -> Dict[str, Any]:
    """
    Get GPU hardware information for MFU calculation.

    Default implementation shared by all generators; handles both
    single-machine and Ray-based multi-machine (worker) setups.

    Returns:
        Dictionary containing:
        - gpu_model: str, GPU model name
        - gpu_count: int, total number of GPUs used
        - gpu_tflops: float, theoretical peak TFLOPS for BF16/FP16
        - tensor_parallel_size: int, tensor parallelism size
        - gpu_memory_total_gb: float, total GPU memory in GB
        - num_workers: int, only present for Ray-based generators
    """
    from benchmark.gpu_utils import get_gpu_info

    info = get_gpu_info()
    tp_size = getattr(self, 'tensor_parallel_size', 1)
    has_workers_attr = hasattr(self, 'workers')
    workers = getattr(self, 'workers', None)
    # Ray-based generators scale GPU count by the number of workers.
    if has_workers_attr and workers:
        info["gpu_count"] = len(workers) * tp_size
    else:
        info["gpu_count"] = tp_size
    info["tensor_parallel_size"] = tp_size
    # Expose worker count only when the generator defines the attribute.
    if has_workers_attr:
        info["num_workers"] = len(workers) if workers else 0
    return info
def _generate_two_stage_with_thinking(
    self,
    prompts: Dict[str, str],
    **kwargs
) -> tuple:
    """
    Two-stage generation with thinking mode.

    Stage 1: Generate thinking content with top_p/top_k sampling until </think>
    Stage 2: Continue generation (with prompt_token if provided, beam search or sampling)

    Shared across all generator implementations; subclasses must implement
    _generate_standard() for this method to work.

    Fix: removed the dead `sample_to_thinking_count` bookkeeping dict that
    was populated but never read.

    Args:
        prompts: {sample_id: prompt_text}
        **kwargs: Optional generation parameters (num_beams,
            num_return_sequences, num_return_thinking_sequences,
            max_new_thinking_tokens, prompt_token, ...)

    Returns:
        Tuple of two dicts:
        - First dict: {sample_id: [generated_text_1, generated_text_2, ...]}
        - Second dict: {sample_id: [cum_logprob_1, ...]} (only for beam search)

    Raises:
        ValueError: If beam-search sequence counts are inconsistently configured
    """
    prompt_token = kwargs.get("prompt_token", None)
    console.print(
        "Stage 1/2: Generating thinking content with top_p/top_k sampling...",
        style=warning_style,
    )
    # Stage 1: Build kwargs for thinking generation (remove beam search, add stop)
    kwargs_stage1 = kwargs.copy()
    kwargs_stage1.pop("num_beams", None)  # Remove beam search to force sampling mode
    kwargs_stage1["stop"] = ["</think>"]  # Stop at </think> tag
    # Use num_return_thinking_sequences for stage 1 if specified
    num_return_thinking = kwargs.get("num_return_thinking_sequences", 1)
    kwargs_stage1["num_return_sequences"] = num_return_thinking
    # Use max_new_thinking_tokens for stage 1 if specified
    max_new_thinking_tokens = kwargs.get("max_new_thinking_tokens", 1000)
    kwargs_stage1["max_new_tokens"] = max_new_thinking_tokens
    # Call _generate_standard for stage 1 (ignoring logprobs as they're not used)
    stage1_results, _, stage1_mfu_stats = self._generate_standard(prompts, **kwargs_stage1)
    # Prepare prompts for stage 2 by appending thinking + prompt_token.
    # Each sample may have multiple thinking candidates; each gets a unique id.
    stage2_prompts = {}
    for sample_id, thinking_list in stage1_results.items():
        # Use ALL thinking candidates (not just the first one)
        for idx, thinking_text in enumerate(thinking_list):
            # Create unique ID for each thinking candidate
            thinking_sample_id = f"{sample_id}_thinking_{idx}"
            # Append </think> + prompt_token (if provided). Stage 1 stops at
            # the tag, so the closing tag must be re-appended here; if the
            # model never emitted </think>, the entire output is treated as
            # thinking.
            if prompt_token:
                full_thinking = thinking_text + "</think>\n" + prompt_token
            else:
                full_thinking = thinking_text + "</think>\n"
            stage2_prompt = prompts[sample_id] + full_thinking
            stage2_prompts[thinking_sample_id] = stage2_prompt
    # Stage 2: Determine generation mode based on num_beams
    kwargs_stage2 = kwargs.copy()
    original_num_sequences = kwargs.get("num_return_sequences", 1)
    original_num_beams = kwargs.get("num_beams", None)
    # Determine if stage 2 uses beam search or sampling
    use_beam_search_stage2 = original_num_beams is not None
    if use_beam_search_stage2:
        # Beam search mode: num_beams is directly used per thinking candidate
        beams_per_thinking = original_num_beams
        # Validate configuration: total sequences should match
        if original_num_sequences != beams_per_thinking * num_return_thinking:
            raise ValueError(
                f"Configuration error: num_return_sequences ({original_num_sequences}) must equal "
                f"num_beams ({beams_per_thinking}) * num_return_thinking_sequences ({num_return_thinking}) = "
                f"{beams_per_thinking * num_return_thinking}. "
                f"Please adjust your parameters accordingly."
            )
        kwargs_stage2["num_return_sequences"] = beams_per_thinking
        kwargs_stage2["num_beams"] = beams_per_thinking
        console.print(
            f"Stage 2/2: Generating sequences with beam search for {len(stage2_prompts)} thinking candidates...",
            style=warning_style,
        )
        console.print(
            f"Each thinking candidate will use beam_width={beams_per_thinking}, return {beams_per_thinking} sequences "
            f"({num_return_thinking} thinking × {beams_per_thinking} = {num_return_thinking * beams_per_thinking} total per sample)",
            style=warning_style,
        )
    else:
        # Sampling mode: each thinking generates 1 result
        kwargs_stage2["num_return_sequences"] = 1
        kwargs_stage2.pop("num_beams", None)  # Remove num_beams to use sampling
        console.print(
            f"Stage 2/2: Generating sequences with sampling for {len(stage2_prompts)} thinking candidates...",
            style=warning_style,
        )
        console.print(
            f"Each thinking candidate will generate 1 sequence "
            f"({num_return_thinking} thinking × 1 = {num_return_thinking} total per sample)",
            style=warning_style,
        )
    # Call _generate_standard for stage 2
    stage2_results, stage2_logprobs, stage2_mfu_stats = self._generate_standard(stage2_prompts, **kwargs_stage2)
    # Merge mfu_stats from both stages, keyed by original sample_id
    self.mfu_stats = {}
    for sample_id, stats in stage1_mfu_stats.items():
        self.mfu_stats[sample_id] = {
            "input_tokens": stats["input_tokens"].copy(),
            "output_tokens": stats["output_tokens"].copy(),
            "times": stats["times"].copy()
        }
    # Group stage2 stats by original_id first
    stage2_by_original = defaultdict(lambda: {"input_tokens": [], "output_tokens": [], "times": []})
    for thinking_id, stats in stage2_mfu_stats.items():
        original_id = thinking_id.rsplit("_thinking_", 1)[0]
        stage2_by_original[original_id]["input_tokens"].extend(stats["input_tokens"])
        stage2_by_original[original_id]["output_tokens"].extend(stats["output_tokens"])
        stage2_by_original[original_id]["times"].extend(stats["times"])
    # Aggregate per sample: sum tokens, max time (candidates run in parallel)
    # NOTE(review): assumes every stage-2 id maps back to a stage-1 sample id;
    # a mismatch would raise KeyError here — confirm against _generate_standard.
    for original_id, stats in stage2_by_original.items():
        self.mfu_stats[original_id]["input_tokens"].append(sum(stats["input_tokens"]))
        self.mfu_stats[original_id]["output_tokens"].append(sum(stats["output_tokens"]))
        self.mfu_stats[original_id]["times"].append(max(stats["times"]))
    # Merge results back by original sample_id:
    # combine thinking + prompt_token + SID into the final generation text
    final_results = defaultdict(list)
    final_logprobs = defaultdict(list)
    for thinking_sample_id, sid_sequences in stage2_results.items():
        # Extract original sample_id and thinking index
        # Format: "sampleID_thinking_N"
        parts = thinking_sample_id.rsplit("_thinking_", 1)
        original_sample_id = parts[0]
        thinking_idx = int(parts[1])
        # Get the corresponding thinking text from stage 1
        thinking_text = stage1_results[original_sample_id][thinking_idx]
        # Combine thinking + prompt_token + SID for each sequence
        for sid_seq in sid_sequences:
            # Format: <think>thinking_text</think>\n<|sid_begin|>sid_sequence
            combined = f"{thinking_text}</think>\n{prompt_token or ''}{sid_seq}"
            final_results[original_sample_id].append(combined)
        # Also merge logprobs if available (from stage 2 beam search)
        if thinking_sample_id in stage2_logprobs:
            final_logprobs[original_sample_id].extend(stage2_logprobs[thinking_sample_id])
    return (dict(final_results), dict(final_logprobs))
def _generate_two_stage_classification_with_thinking(
    self,
    prompts: Dict[str, str],
    **kwargs
) -> tuple:
    """
    Two-stage generation for classification tasks with thinking mode

    Stage 1: Generate thinking content with top_p/top_k sampling until </think>
    Stage 2: Extract logprobs for target tokens for each thinking candidate

    This method is shared across all generator implementations to reduce code duplication.
    Subclasses must implement _generate_standard() and extract_token_logprobs() for this method to work.

    Side effects:
        Overwrites self.mfu_stats with per-sample token/time statistics merged from
        both stages (stage-2 stats are summed per original sample, times aggregated by max).

    Args:
        prompts: {sample_id: prompt_text}
        **kwargs: Optional generation parameters. Must contain "target_tokens";
            may contain "num_return_thinking_sequences" (default 1) and
            "max_new_thinking_tokens" (default 1000) used for stage 1.

    Returns:
        Tuple of two dicts:
        - First dict: {sample_id: ["<think>thinking_1</think>\n{'是': 0.8, '否': 0.2}", ...]}
        - Second dict: {} (empty, no logprobs for classification)
    """
    # target_tokens is guaranteed to be in kwargs (checked in generate() method)
    target_tokens = kwargs["target_tokens"]
    console.print(
        "Stage 1/2: Generating thinking content with top_p/top_k sampling...",
        style=warning_style,
    )
    # Stage 1: Build kwargs for thinking generation (remove beam search, add stop)
    kwargs_stage1 = kwargs.copy()
    kwargs_stage1.pop("num_beams", None)  # Remove beam search to force sampling mode
    kwargs_stage1.pop("target_tokens", None)  # Remove target_tokens for stage 1
    kwargs_stage1["stop"] = ["</think>"]  # Stop at </think> tag
    # Use num_return_thinking_sequences for stage 1 if specified
    num_return_thinking = kwargs.get("num_return_thinking_sequences", 1)
    kwargs_stage1["num_return_sequences"] = num_return_thinking
    # Use max_new_thinking_tokens for stage 1 if specified
    max_new_thinking_tokens = kwargs.get("max_new_thinking_tokens", 1000)
    kwargs_stage1["max_new_tokens"] = max_new_thinking_tokens
    # Call _generate_standard for stage 1 (ignoring logprobs as they're not used)
    stage1_results, _, stage1_mfu_stats = self._generate_standard(prompts, **kwargs_stage1)
    # Prepare prompts for stage 2 by appending thinking + </think>
    # Each sample will have multiple thinking candidates
    stage2_prompts = {}
    # Track how many thinking candidates each sample has.
    # NOTE(review): this counter is not read later in this method.
    sample_to_thinking_count = {}
    for sample_id, thinking_list in stage1_results.items():
        # Use ALL thinking candidates (not just the first one)
        sample_to_thinking_count[sample_id] = len(thinking_list)
        for idx, thinking_text in enumerate(thinking_list):
            # Create unique ID for each thinking candidate
            # (format "sampleID_thinking_N" is parsed back below via rsplit)
            thinking_sample_id = f"{sample_id}_thinking_{idx}"
            # Append </think> to complete the thinking tag
            full_thinking = thinking_text + f"</think>\n"
            stage2_prompt = prompts[sample_id] + full_thinking
            stage2_prompts[thinking_sample_id] = stage2_prompt
    console.print(
        f"Stage 2/2: Extracting logprobs for {len(stage2_prompts)} thinking candidates...",
        style=warning_style,
    )
    console.print(
        f"Each thinking candidate will extract logprobs for tokens {target_tokens} "
        f"({num_return_thinking} thinking total per sample)",
        style=warning_style,
    )
    # Build kwargs for stage 2 (remove target_tokens to avoid duplication)
    kwargs_stage2 = kwargs.copy()
    kwargs_stage2.pop("target_tokens", None)
    # Call extract_token_logprobs for stage 2
    stage2_probs, _, stage2_mfu_stats = self.extract_token_logprobs(stage2_prompts, target_tokens, **kwargs_stage2)
    # Merge mfu_stats from both stages.
    # Reset the accumulator so repeated calls do not mix runs.
    self.mfu_stats = {}
    for sample_id, stats in stage1_mfu_stats.items():
        # Copy the lists so later appends cannot mutate stage-1 stats in place
        self.mfu_stats[sample_id] = {
            "input_tokens": stats["input_tokens"].copy(),
            "output_tokens": stats["output_tokens"].copy(),
            "times": stats["times"].copy()
        }
    # Group stage2 stats by original_id first
    stage2_by_original = defaultdict(lambda: {"input_tokens": [], "output_tokens": [], "times": []})
    for thinking_id, stats in stage2_mfu_stats.items():
        # Recover the original sample id from "sampleID_thinking_N"
        original_id = thinking_id.rsplit("_thinking_", 1)[0]
        stage2_by_original[original_id]["input_tokens"].extend(stats["input_tokens"])
        stage2_by_original[original_id]["output_tokens"].extend(stats["output_tokens"])
        stage2_by_original[original_id]["times"].extend(stats["times"])
    # Aggregate: sum tokens, max time
    # (candidates run concurrently, so wall time is the slowest one — TODO confirm)
    for original_id, stats in stage2_by_original.items():
        self.mfu_stats[original_id]["input_tokens"].append(sum(stats["input_tokens"]))
        self.mfu_stats[original_id]["output_tokens"].append(sum(stats["output_tokens"]))
        self.mfu_stats[original_id]["times"].append(max(stats["times"]))
    # Merge results back by original sample_id
    # Combine thinking + probabilities into final generation
    final_results = defaultdict(list)
    for thinking_sample_id, json_str_list in stage2_probs.items():
        # Extract original sample_id and thinking index
        # Format: "sampleID_thinking_N"
        parts = thinking_sample_id.rsplit("_thinking_", 1)
        original_sample_id = parts[0]
        thinking_idx = int(parts[1])
        # Get the corresponding thinking text from stage 1
        thinking_text = stage1_results[original_sample_id][thinking_idx]
        # Extract JSON string from list (extract_token_logprobs returns [json_str])
        json_str = json_str_list[0]
        # Combine thinking + probabilities (json_str is already formatted)
        # Format: "<think>thinking_text</think>\n{\"是\": 0.8, \"否\": 0.2}"
        combined = f"{thinking_text}</think>\n{json_str}"
        final_results[original_sample_id].append(combined)
    return (dict(final_results), {})
class HfTransformersMixin:
    """
    Mixin for HuggingFace Transformers functionality

    Provides common parameter building logic for HuggingFace Transformers
    generate() API. Combine with Generator or RayMixin to create
    HuggingFace-based generators.
    """

    def _build_sampling_params(self, **kwargs) -> tuple:
        """
        Build HuggingFace sampling/generation parameters.

        Args:
            **kwargs: Optional parameters overriding defaults. Recognized keys:
                num_return_sequences, max_new_tokens, num_beams, stop,
                temperature, top_p, top_k, repetition_penalty,
                presence_penalty, frequency_penalty, do_sample.

        Returns:
            Tuple of (gen_kwargs dict, stop_sequences list).

        Raises:
            ValueError: in beam-search mode, when num_return_sequences
                exceeds num_beams.
        """
        requested = kwargs.get("num_return_sequences")
        token_budget = kwargs.get("max_new_tokens")
        beams = kwargs.get("num_beams", None)
        stops = kwargs.get("stop", [])

        if beams is None:
            # Sampling mode: pull every knob from kwargs with its default.
            sampling_kwargs = {
                "num_return_sequences": requested,
                "max_new_tokens": token_budget,
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.9),
                "top_k": kwargs.get("top_k", -1),
                "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
                "presence_penalty": kwargs.get("presence_penalty", 0.0),
                "frequency_penalty": kwargs.get("frequency_penalty", 0.0),
                "do_sample": kwargs.get("do_sample", True),
            }
            return sampling_kwargs, stops

        # Beam search mode: cannot return more sequences than beams exist.
        if requested and requested > beams:
            raise ValueError(
                f"num_return_sequences ({requested}) cannot be greater than num_beams ({beams}). "
                f"Beam search can only return at most {beams} sequences. "
                f"Please set num_return_sequences <= num_beams or increase num_beams."
            )
        beam_kwargs = {
            "num_beams": beams,
            "num_return_sequences": requested or beams,
            "max_new_tokens": token_budget,
            "do_sample": False,
            "output_scores": True,
            "return_dict_in_generate": True,
        }
        # repetition_penalty is only forwarded when the caller supplied it.
        if "repetition_penalty" in kwargs:
            beam_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
        return beam_kwargs, stops
class VllmMixin:
    """
    Mixin for vLLM functionality

    Provides common parameter building logic for the vLLM generate() API,
    plus a policy check for enabling vLLM engine optimizations. Combine with
    Generator or RayMixin to create vLLM-based generators.
    """

    def _build_sampling_params(self, **kwargs):
        """
        Build vLLM sampling parameters.

        Args:
            **kwargs: Optional parameters overriding defaults. When
                "num_beams" is present, beam search parameters are built;
                otherwise standard sampling parameters are returned.

        Returns:
            SamplingParams or BeamSearchParams object.
        """
        from vllm import SamplingParams
        from vllm.sampling_params import BeamSearchParams

        token_budget = kwargs.get("max_new_tokens")
        sequence_count = kwargs.get("num_return_sequences", 1)
        beams = kwargs.get("num_beams", None)

        if beams is not None:
            # Beam search: widen the beam so it covers the requested
            # number of returned sequences.
            return BeamSearchParams(
                beam_width=max(beams, sequence_count),
                max_tokens=token_budget,
            )

        # Sampling mode: forward every knob with its default.
        return SamplingParams(
            n=sequence_count,
            temperature=kwargs.get("temperature", 0.7),
            top_p=kwargs.get("top_p", 0.9),
            top_k=kwargs.get("top_k", -1),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            presence_penalty=kwargs.get("presence_penalty", 0.0),
            frequency_penalty=kwargs.get("frequency_penalty", 0.0),
            max_tokens=token_budget,
            stop=kwargs.get("stop", None),
        )

    def _should_enable_optimizations(self) -> bool:
        """
        Decide whether to enable vLLM optimizations (chunked_prefill,
        prefix_caching) based on force flags and the configured task types.

        Returns:
            True if optimizations should be enabled, False otherwise.
        """
        # Priority 1: explicit force flags win outright.
        if self.force_enable_optimizations:
            return True
        if self.force_disable_optimizations:
            return False
        # Priority 2: any configured task on the deny-list disables them.
        configured_tasks = getattr(self, 'task_types', None)
        if configured_tasks:
            if any(task in DISABLE_OPTIMIZATIONS_FOR_TASKS for task in configured_tasks):
                return False
        # Default: optimizations on.
        return True
class RayMixin:
    """
    Mixin for Ray distributed computing functionality

    Provides Ray cluster management, GPU allocation, and resource cleanup
    for distributed generators. This is a mixin class designed to be combined
    with other generator classes using multiple inheritance.

    NOTE(review): the host class is expected to provide the attributes read
    here — ray_address, gpu_ids, num_gpus, tensor_parallel_size,
    allow_cross_node_tensor_parallel, worker_gpu_groups,
    worker_node_assignments, workers — confirm against concrete generators.
    """
    def _initialize_ray_cluster(self):
        """Initialize Ray cluster connection.

        Connection mode is chosen from self.ray_address:
        "local" starts a single-machine Ray, "auto" tries to join an
        existing cluster (falling back to local on failure), and any other
        value is treated as an explicit cluster address.
        """
        import ray
        # Reuse an existing Ray session instead of re-initializing
        if ray.is_initialized():
            console.print(
                " ✓ Ray already initialized",
                style=success_style,
            )
            return
        console.print(
            " Initializing Ray cluster connection...",
            style=subhead_style_2,
        )
        # Determine connection mode
        if self.ray_address == "local":
            # Local mode (single machine)
            ray.init(ignore_reinit_error=True)
            console.print(
                " ✓ Ray initialized in local mode",
                style=success_style,
            )
        elif self.ray_address == "auto":
            # Auto-detect mode
            try:
                ray.init(address="auto", ignore_reinit_error=True)
                console.print(
                    " ✓ Ray connected to existing cluster (auto-detected)",
                    style=success_style,
                )
            except Exception:
                # Fallback to local mode
                console.print(
                    " [yellow]No existing cluster found, initializing local mode...[/yellow]",
                    style=warning_style,
                )
                ray.init(ignore_reinit_error=True)
                console.print(
                    " ✓ Ray initialized in local mode",
                    style=success_style,
                )
        else:
            # Specific address
            ray.init(address=self.ray_address, ignore_reinit_error=True)
            console.print(
                f" ✓ Ray connected to cluster at {self.ray_address}",
                style=success_style,
            )
    def _determine_gpu_ids_from_cluster(self) -> List[Dict[str, Any]]:
        """
        Determine GPU resources from Ray cluster

        Enumerates GPUs on every alive node, then applies user filters:
        self.gpu_ids selects by global index, otherwise self.num_gpus caps
        the count. Out-of-range indices are skipped with a warning.

        Returns:
            List of GPU info dicts: [{"node_id": str, "gpu_index": int}, ...]
            (each entry also carries "node_ip" and "global_index")

        Raises:
            RuntimeError: when no GPUs are found on any alive node.
        """
        import ray
        # Get all nodes in cluster
        nodes = ray.nodes()
        # Collect GPU information from all nodes
        gpu_list = []
        for node in nodes:
            if not node['Alive']:
                continue
            node_id = node['NodeID']
            node_resources = node.get('Resources', {})
            # Count GPUs on this node
            num_gpus_on_node = int(node_resources.get('GPU', 0))
            if num_gpus_on_node > 0:
                # Add GPU entries for this node
                for gpu_idx in range(num_gpus_on_node):
                    gpu_list.append({
                        "node_id": node_id,
                        "node_ip": node.get('NodeManagerAddress', 'unknown'),
                        "gpu_index": gpu_idx,
                        "global_index": len(gpu_list) # Global GPU index across cluster
                    })
        if not gpu_list:
            raise RuntimeError("No GPUs detected in Ray cluster")
        # Apply user filters if specified
        if self.gpu_ids is not None:
            # In cluster mode, gpu_ids refers to global indices
            filtered_list = []
            for idx in self.gpu_ids:
                if idx < len(gpu_list):
                    filtered_list.append(gpu_list[idx])
                else:
                    console.print(
                        f" [yellow]Warning:[/yellow] GPU index {idx} out of range (max: {len(gpu_list)-1}), skipping",
                        style=warning_style,
                    )
            gpu_list = filtered_list
        elif self.num_gpus is not None:
            # Limit to first num_gpus
            if self.num_gpus < len(gpu_list):
                gpu_list = gpu_list[:self.num_gpus]
            elif self.num_gpus > len(gpu_list):
                console.print(
                    f" [yellow]Warning:[/yellow] Requested {self.num_gpus} GPUs, but only {len(gpu_list)} available in cluster",
                    style=warning_style,
                )
        return gpu_list
    def _group_gpus_for_workers(
        self,
        gpu_list: List[Dict[str, Any]],
        tensor_parallel_size: int
    ) -> tuple:
        """
        Group GPUs for workers, ensuring same-node constraint for tensor parallelism

        Args:
            gpu_list: List of GPU info dicts
            tensor_parallel_size: Number of GPUs per worker
        Returns:
            (worker_gpu_groups, worker_node_assignments)
            - worker_gpu_groups: List of GPU index lists for each worker
            - worker_node_assignments: List of node IDs for each worker
        Raises:
            ValueError: when the GPU count is not divisible by
                tensor_parallel_size, or when the same-node constraint
                prevents forming the expected number of workers.
        """
        if len(gpu_list) % tensor_parallel_size != 0:
            raise ValueError(
                f"Number of GPUs ({len(gpu_list)}) must be divisible by tensor_parallel_size ({tensor_parallel_size})"
            )
        num_workers = len(gpu_list) // tensor_parallel_size
        worker_gpu_groups = []
        worker_node_assignments = []
        if tensor_parallel_size == 1:
            # Simple case: one GPU per worker
            for gpu_info in gpu_list:
                worker_gpu_groups.append([gpu_info["gpu_index"]])
                worker_node_assignments.append(gpu_info["node_id"])
        else:
            # Complex case: multiple GPUs per worker
            # Need to ensure all GPUs in a group are on the same node
            if not self.allow_cross_node_tensor_parallel:
                # Group by node first
                node_to_gpus = {}
                for gpu_info in gpu_list:
                    node_id = gpu_info["node_id"]
                    if node_id not in node_to_gpus:
                        node_to_gpus[node_id] = []
                    node_to_gpus[node_id].append(gpu_info)
                # Create workers from each node
                for node_id, node_gpus in node_to_gpus.items():
                    # Group GPUs on this node; leftover GPUs that cannot
                    # fill a full group are dropped (caught by the count
                    # check below)
                    for i in range(0, len(node_gpus), tensor_parallel_size):
                        if i + tensor_parallel_size <= len(node_gpus):
                            gpu_group = [gpu["gpu_index"] for gpu in node_gpus[i:i+tensor_parallel_size]]
                            worker_gpu_groups.append(gpu_group)
                            worker_node_assignments.append(node_id)
                if len(worker_gpu_groups) != num_workers:
                    raise ValueError(
                        f"Cannot create {num_workers} workers with tensor_parallel_size={tensor_parallel_size} "
                        f"while ensuring same-node constraint. Got {len(worker_gpu_groups)} workers instead. "
                        f"Try setting --allow_cross_node_tensor_parallel or adjust tensor_parallel_size."
                    )
            else:
                # Allow cross-node tensor parallel (not recommended)
                console.print(
                    " [yellow]Warning: Cross-node tensor parallelism enabled. This may cause performance degradation.[/yellow]",
                    style=warning_style,
                )
                for i in range(num_workers):
                    start_idx = i * tensor_parallel_size
                    end_idx = start_idx + tensor_parallel_size
                    gpu_group = [gpu_list[j]["gpu_index"] for j in range(start_idx, end_idx)]
                    worker_gpu_groups.append(gpu_group)
                    # Use first GPU's node as primary node
                    worker_node_assignments.append(gpu_list[start_idx]["node_id"])
        return worker_gpu_groups, worker_node_assignments
    def _display_cluster_info(self, gpu_list: List[Dict[str, Any]], num_workers: int):
        """Display cluster and GPU information.

        Prints node/GPU counts, tensor-parallel size, and per-worker GPU
        assignments (read from self.worker_gpu_groups and
        self.worker_node_assignments).
        """
        import ray
        # Get cluster info
        nodes = ray.nodes()
        alive_nodes = [n for n in nodes if n['Alive']]
        console.print(
            f" Cluster nodes: [green]{len(alive_nodes)}[/green]",
            style=subhead_style_2,
        )
        # Group GPUs by node
        node_gpu_count = {}
        for gpu_info in gpu_list:
            node_ip = gpu_info["node_ip"]
            node_gpu_count[node_ip] = node_gpu_count.get(node_ip, 0) + 1
        for node_ip, count in node_gpu_count.items():
            console.print(
                f" - Node {node_ip}: {count} GPU(s)",
                style=subhead_style_2,
            )
        console.print(
            f" Total GPUs: [green]{len(gpu_list)}[/green]",
            style=subhead_style_2,
        )
        console.print(
            f" Tensor Parallel Size: [green]{self.tensor_parallel_size}[/green]",
            style=subhead_style_2,
        )
        console.print(
            f" Worker count: [green]{num_workers}[/green]",
            style=subhead_style_2,
        )
        # Display worker assignments
        console.print(
            f" Worker GPU assignments:",
            style=subhead_style_2,
        )
        for i, (gpu_group, node_id) in enumerate(zip(self.worker_gpu_groups, self.worker_node_assignments)):
            # Find node IP for this node_id
            node_ip = "unknown"
            for gpu_info in gpu_list:
                if gpu_info["node_id"] == node_id:
                    node_ip = gpu_info["node_ip"]
                    break
            console.print(
                f" - Worker {i}: GPUs {gpu_group} on node {node_ip}",
                style=subhead_style_2,
            )
    def cleanup(self):
        """
        Explicitly cleanup resources and release GPU memory

        Called after generation tasks complete to release GPU memory occupied by Ray Workers.
        This is useful for avoiding OOM errors during subsequent metric calculations.

        Kills every actor in self.workers (failures are reported but do not
        stop the loop), then shuts Ray down if it is initialized. Any
        unexpected error is caught and printed rather than raised.
        """
        import ray
        console.print(
            "\nReleasing Ray Workers and resources...",
            style=warning_style,
        )
        try:
            # 1. Cleanup all Workers
            if hasattr(self, 'workers') and self.workers:
                for i, worker in enumerate(self.workers):
                    try:
                        ray.kill(worker)
                        console.print(
                            f" ✓ Worker {i} terminated",
                            style=success_style,
                        )
                    except Exception as e:
                        # Best-effort: report and continue with the next worker
                        console.print(
                            f" ⚠ Worker {i} cleanup failed: {e}",
                            style=err_style,
                        )
                self.workers = []
            # 2. Shut down Ray (optional)
            if ray.is_initialized():
                console.print(
                    " Shutting down Ray...",
                    style=subhead_style_2,
                )
                ray.shutdown()
                console.print(
                    " ✓ Ray shut down",
                    style=subhead_style_2,
                )
            console.print(
                "✓ Resource cleanup completed\n",
                style=success_style,
            )
        except Exception as e:
            console.print(
                f"✗ Cleanup process error: {e}",
                style=err_style,
            )
================================================
FILE: benchmarks/benchmark/benchmark.py
================================================
import os
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
from datetime import datetime
from benchmark.console import *
from benchmark.generation_runner import GenerationRunner
from benchmark.base_generator import Generator
from benchmark.tasks import (
BenchmarkTable,
LATEST_BENCHMARK_VERSION,
check_benchmark_version,
check_task_types,
check_splits,
)
from benchmark.tasks.v1_0.registry import get_loader, get_evaluator, get_task_config
class DataLoaderWrapper:
    """Wrapper for unified data loading interface.

    Lazily builds one task loader per task name and caches it, sharing a
    single tokenizer (created from model_path when one is given).
    """

    def __init__(self, model_path: str, benchmark_version: str, data_dir: str, enable_thinking: Optional[bool] = None):
        self.model_path = model_path
        # Only build a tokenizer when a model path was actually supplied.
        self._tokenizer = self._create_tokenizer(model_path) if model_path else None
        self.benchmark_version = benchmark_version
        self.data_dir = data_dir
        self.enable_thinking = enable_thinking
        # task_name -> loader instance, filled on demand by load_data()
        self._loader_cache = {}

    def _create_tokenizer(self, model_path: str):
        """Create tokenizer from model path; wraps failures in RuntimeError."""
        try:
            from transformers import AutoTokenizer
            tok = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer from {model_path}: {e}")
        console.print(f"[green]Tokenizer loaded from: {model_path}[/green]")
        return tok

    def load_data(self, task_name: str, split: str = "test", sample_size: Optional[Any] = None):
        """Load data using new loader system (loader is created once per task)."""
        if task_name not in self._loader_cache:
            loader = get_loader(
                task_name=task_name,
                data_dir=self.data_dir,
                tokenizer=self._tokenizer,
                enable_thinking=self.enable_thinking,
            )
            self._loader_cache[task_name] = loader
        return self._loader_cache[task_name].load_data(split=split, sample_size=sample_size)
class Benchmark:
    """
    Benchmark Generation Task Evaluation Framework

    Drives two phases: run() performs generation over the configured tasks
    and splits, and evaluate_dev() batch-scores previously generated
    results and writes a JSON report.

    Usage Example:
        from benchmark import Benchmark
        from your_generator import YourGenerator

        benchmark = Benchmark(
            data_dir="./data"
        )
        generator = YourGenerator("your-model-path")
        benchmark.run(
            generator=generator,
            output_dir="./results"
        )
    """
    def __init__(
        self,
        model_path: Optional[str] = None,
        task_types: Optional[List[str]] = None,
        splits: Optional[List[str]] = None,
        data_dir: Optional[str] = None,
        enable_thinking: Optional[bool] = None,
    ):
        """Initialize evaluation framework.

        Task types and splits are validated against the latest benchmark
        version; a DataLoaderWrapper is created for data access.
        """
        self.benchmark_version = LATEST_BENCHMARK_VERSION
        self.data_dir = data_dir
        self.task_types = check_task_types(task_types, self.benchmark_version)
        self.splits = check_splits(splits, self.benchmark_version)
        self.data_loader = DataLoaderWrapper(
            model_path=model_path,
            benchmark_version=self.benchmark_version,
            data_dir=data_dir,
            enable_thinking=enable_thinking,
        )
    @staticmethod
    def print_benchmark_table():
        """Print all available benchmark versions and tasks"""
        for benchmark_version in BenchmarkTable:
            console.print(
                head_print(f"Benchmark Dataset Version: {benchmark_version}"),
                style=head_style,
                justify="center",
            )
            task_types_list = list(BenchmarkTable[benchmark_version].keys())
            total_task_types = len(task_types_list)
            for task_idx, task_type in enumerate(task_types_list, start=1):
                console.print(
                    f"\nTask Type [{task_idx}/{total_task_types}]: {task_type}\n",
                    style=subhead_style,
                    justify="center"
                )
                task_config = BenchmarkTable[benchmark_version][task_type]
                console.print(
                    f"Dataset Name: {task_config.get('name', task_type)}",
                    style=row_style,
                    justify="center",
                )
                console.print(
                    f"Source: {task_config.get('source', 'N/A')}",
                    style=row_style,
                    justify="center",
                )
                console.print(
                    f"Splits: {task_config.get('splits', [])}",
                    style=row_style,
                    justify="center",
                )
                console.print(
                    f"Sample Size: {task_config.get('sample_size', 'N/A')}",
                    style=row_style,
                    justify="center",
                )
                console.print(
                    f"Description: {task_config.get('description', 'N/A')}",
                    style=row_style,
                    justify="center",
                )
    @staticmethod
    def check_generator(generator):
        """Verify that generator implements required methods.

        Raises:
            ValueError: when the generator is missing a required method or
                the method is not callable.
        """
        required_methods = ["__str__", "generate"]
        for method in required_methods:
            if not hasattr(generator, method):
                raise ValueError(f"Generator should have `{method}` method.")
            # __str__ always exists on objects; only check callability of the rest
            if method != "__str__" and not callable(getattr(generator, method, None)):
                raise ValueError(f"Generator.{method} should be callable.")
    def run(
        self,
        generator: Generator,
        output_dir: str = "./results",
        overwrite: bool = False,
        **kwargs
    ):
        """Run benchmark evaluation.

        Iterates over every configured (task, split) pair and runs
        generation via GenerationRunner; metrics are NOT computed here.
        Failures of individual tasks are caught and reported so remaining
        tasks still run. A summary is printed at the end.
        """
        self.check_generator(generator)
        console.print(f"\n\nStarting generation\n\n", style=head_style, justify="center")
        generation_runner = GenerationRunner(self.data_loader, overwrite=overwrite)
        total_tasks = 0
        completed_tasks = 0
        task_table = BenchmarkTable[self.benchmark_version]
        # First pass: count runnable (task, split) pairs for progress display
        for task_name in self.task_types:
            if task_name not in task_table:
                continue
            task_config = task_table[task_name]
            available_splits = task_config.get("splits", ["test"])
            for split in self.splits:
                if split in available_splits:
                    total_tasks += 1
        # Second pass: actually run each task/split
        for task_name in self.task_types:
            if task_name not in task_table:
                console.print(f"Task does not exist: {task_name}")
                continue
            task_config = task_table[task_name]
            available_splits = task_config.get("splits", ["test"])
            # Iterate through all splits
            for split in self.splits:
                if split not in available_splits:
                    console.print(f"Split does not exist: {split} (task: {task_name})")
                    continue
                # Determine displayed sample size
                sample_size_param = kwargs.get('sample_size')
                if sample_size_param is not None:
                    if sample_size_param == "full":
                        display_sample_size = task_config.get('size', 'N/A')
                    else:
                        display_sample_size = int(sample_size_param)
                else:
                    display_sample_size = task_config.get('sample_size', 'N/A')
                console.print(
                    f"\nTask [{completed_tasks + 1}/{total_tasks}]: {task_name} | Split: {split} | Sample Size: {display_sample_size}\n",
                    style=subhead_style,
                    justify="center",
                )
                try:
                    task_gen_config = task_config.get("generation_config", {})
                    prompt_config = task_config.get("prompt_config", {})
                    # Merge generation parameters (priority: user input > task config > Generator init parameters)
                    # Filter out None values from kwargs to avoid overwriting task config
                    valid_kwargs = {k: v for k, v in kwargs.items() if v is not None}
                    merged_kwargs = {**task_gen_config, **prompt_config, **valid_kwargs}
                    print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]")
                    # Execute generation (without computing metrics)
                    generation_runner(
                        task_name=task_name,
                        split=split,
                        results_save_dir=output_dir,
                        generator=generator,
                        **merged_kwargs
                    )
                    completed_tasks += 1
                except Exception as e:
                    import traceback
                    console.print(f"✗ Task failed: {task_name}/{split}", style=err_style)
                    console.print(f"✗ Error type: {type(e).__name__}", style=err_style)
                    console.print(f"✗ Error message: {str(e)}", style=err_style)
                    console.print("✗ Full stack trace:", style=err_style)
                    console.print(traceback.format_exc(), style=dim_style)
        console.print(f"Total tasks: {total_tasks}")
        console.print(f"Completed tasks: {completed_tasks}")
        console.print(f"Failed tasks: {total_tasks - completed_tasks}")
        console.print(f"Results saved to: {output_dir}")
    @staticmethod
    def _evaluate_single_task(
        task_name: str,
        task_dir: str,
        generation_file: str,
        split: str,
        data_dir: str,
        overwrite: bool,
        valid_kwargs: Dict[str, Any],
        cached_metrics: Optional[Dict[str, Any]] = None
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Evaluate a single task split.

        Reads the generation JSON, runs the task's evaluator, optionally
        adds MFU metrics, merges per-sample metrics back into the samples,
        and rewrites the generation file in place.

        Returns:
            (gen_data, metrics, samples) on success; (None, None, None)
            when the task is unknown or evaluation fails.

        Raises:
            ValueError: when the generation file lacks a "samples" field.
        """
        # Read generation results
        with open(generation_file, 'r', encoding='utf-8') as f:
            gen_data = json.load(f)
        if "samples" not in gen_data:
            raise ValueError("Generation result file missing 'samples' field (new format).")
        samples = gen_data["samples"]
        # Load task configuration
        if task_name not in BenchmarkTable[LATEST_BENCHMARK_VERSION]:
            console.print(f"⚠ Warning: Task '{task_name}' not found in BenchmarkTable[{LATEST_BENCHMARK_VERSION}], skipping...", style=warning_style)
            return None, None, None
        try:
            evaluator_class = get_evaluator(task_name=task_name)
            task_config = get_task_config(task_name=task_name)
            # User-supplied kwargs override the task's evaluation_config
            task_config['evaluation_config'].update(valid_kwargs)
            evaluator = evaluator_class(
                samples=samples,
                task_name=task_name,
                predictions_dir=task_dir,
                debug=True, # Enable debug mode for detailed info
                task_config=task_config,
                data_dir=data_dir,
                overwrite=overwrite,
                cached_metrics=cached_metrics
            )
            console.print(f"Using {evaluator_class.__name__} for {task_name}")
            metrics, per_sample_metrics = evaluator.evaluate()
            # Compute MFU metrics if hardware info and token stats are available
            try:
                from benchmark.tasks.v1_0.mfu_evaluator import compute_mfu_from_generation_data
                mfu_metrics = compute_mfu_from_generation_data(gen_data)
                if mfu_metrics:
                    metrics.update(mfu_metrics)
                    # Display MFU for each stage
                    if "mfu" in mfu_metrics:
                        mfu_list = mfu_metrics["mfu"]
                        if len(mfu_list) == 1:
                            console.print(f"✓ MFU: {mfu_list[0]:.2%}", style=success_style)
                        else:
                            mfu_values = [f"Stage{i+1}: {mfu:.2%}" for i, mfu in enumerate(mfu_list)]
                            console.print(f"✓ MFU (multi-stage): {', '.join(mfu_values)}", style=success_style)
            except Exception as e:
                # MFU is best-effort; evaluation continues without it
                console.print(f"⚠ Warning: MFU calculation failed: {e}", style=warning_style)
            # Update samples with per-sample metrics
            for sample_id, sample_metrics in per_sample_metrics.items():
                if sample_id in samples:
                    samples[sample_id].update(sample_metrics)
            # Write updated data back to generation result file
            gen_data["samples"] = samples
            with open(generation_file, 'w', encoding='utf-8') as f:
                json.dump(gen_data, f, indent=2, ensure_ascii=False)
            console.print(f"Updated sample metrics to: {generation_file}")
            return gen_data, metrics, samples
        except Exception as e:
            console.print(f"✗ Error evaluating {task_name}: {e}", style=err_style)
            console.print(f"Skipping task {task_name}", style=warning_style)
            return None, None, None
    @staticmethod
    def _create_debug_file(generation_file: str, gen_data: Dict[str, Any], samples: Dict[str, Any], overwrite: bool = False) -> None:
        """Create debug file with first 100 samples.

        Writes "<generation_file>.debug" containing run metadata plus the
        first 100 samples by sorted id; skipped when the file exists and
        overwrite is False.
        """
        debug_file = f"{generation_file}.debug"
        if overwrite or not os.path.exists(debug_file):
            sorted_ids = sorted(samples.keys())
            debug_sample_ids = sorted_ids[:100]
            debug_samples = {id: samples[id] for id in debug_sample_ids}
            debug_data = {
                "model_name": gen_data.get("model_name", ""),
                "task_name": gen_data.get("task_name", ""),
                "split": gen_data.get("split", ""),
                "total_time": gen_data.get("total_time", 0),
                "avg_time_per_sample": gen_data.get("avg_time_per_sample", 0),
                "samples": debug_samples,
            }
            with open(debug_file, 'w', encoding='utf-8') as f:
                json.dump(debug_data, f, indent=2, ensure_ascii=False)
            console.print(f"Created debug file: {debug_file}")
    @staticmethod
    def _calculate_model_total_time(model_results: Dict[str, Any]) -> float:
        """Calculate total time for all tasks of a model.

        Sums "total_time" across every split of every task, skipping
        bookkeeping keys that start with "_" (e.g. "_total_time").
        """
        model_total_time = 0
        for task_name, task_results in model_results.items():
            if task_name.startswith("_"):
                continue
            for split, split_metrics in task_results.items():
                model_total_time += split_metrics.get("total_time", 0)
        return model_total_time
    @staticmethod
    def _save_results_as_json(eval_results: Dict[str, Any], output_path: str) -> None:
        """Save evaluation results as JSON (parent directory is created if needed)."""
        os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(eval_results, f, indent=2, ensure_ascii=False)
        console.print(f"\n\n✓ Results Saved to {output_path}\n\n", style=success_style, justify="center")
    @staticmethod
    def _load_existing_results(output_path: str, task_types: List[str] = None) -> dict:
        """Load existing evaluation results from JSON file for incremental update.

        Returns an empty dict when the file does not exist, is not a
        .json path, or fails to parse (a warning is printed in that case).
        """
        eval_results = {}
        if os.path.exists(output_path) and output_path.endswith('.json'):
            try:
                with open(output_path, 'r', encoding='utf-8') as f:
                    eval_results = json.load(f)
                console.print(f"✓ Loaded existing results from {output_path}", style=success_style, justify="center")
                if task_types is not None:
                    console.print(f" Will update only specified tasks: {', '.join(task_types)}", style=success_style, justify="center")
            except Exception as e:
                console.print(f"⚠ Warning: Failed to load existing results: {e}", style=err_style, justify="center")
                console.print(f" Starting with empty results", style=err_style, justify="center")
                eval_results = {}
        return eval_results
    @staticmethod
    def evaluate_dev(
        generation_results_dir: str,
        output_path: str = "./eval_results.json",
        data_dir: str = None,
        overwrite: bool = False,
        task_types: List[str] = None,
        **kwargs
    ):
        """Batch evaluate generated results and generate report.

        Walks generation_results_dir as <model>/<task>/<split>_generated.json,
        evaluates each file, and incrementally merges metrics into the JSON
        report at output_path. When task_types is given, only those tasks
        are (re-)evaluated; previously saved results are kept.
        """
        valid_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        console.print(f"\n\nMetric Calculation\n", style=head_style, justify="center")
        console.print(f"Result Directory: {generation_results_dir}\n\n", style=head_style, justify="center")
        if not os.path.exists(generation_results_dir):
            console.print(f"✗ Error: Result Directory Not Found: {generation_results_dir}", style=err_style, justify="center")
            return
        eval_results = Benchmark._load_existing_results(output_path, task_types)
        for model_name in os.listdir(generation_results_dir):
            model_dir = os.path.join(generation_results_dir, model_name)
            if not os.path.isdir(model_dir):
                continue
            if model_name not in eval_results:
                eval_results[model_name] = {}
            all_tasks = [t for t in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, t))]
            if task_types is not None:
                all_tasks = [t for t in all_tasks if t in task_types]
            total_tasks_count = len(all_tasks)
            for task_idx, task_name in enumerate(all_tasks, start=1):
                task_dir = os.path.join(model_dir, task_name)
                console.print(f"\nTask [{task_idx}/{total_tasks_count}]: {task_name}\n", style=subhead_style, justify="center")
                if task_name not in eval_results[model_name]:
                    eval_results[model_name][task_name] = {}
                for filename in os.listdir(task_dir):
                    if not filename.endswith('_generated.json'):
                        continue
                    # Split name is encoded in the filename: "<split>_generated.json"
                    split = filename.replace('_generated.json', '')
                    generation_file = os.path.join(task_dir, filename)
                    cached_metrics = eval_results.get(model_name, {}).get(task_name, {}).get(split, {})
                    # Evaluate single task
                    gen_data, metrics, samples = Benchmark._evaluate_single_task(
                        task_name=task_name,
                        task_dir=task_dir,
                        generation_file=generation_file,
                        split=split,
                        data_dir=data_dir,
                        overwrite=overwrite,
                        valid_kwargs=valid_kwargs,
                        cached_metrics=cached_metrics
                    )
                    if gen_data is None:
                        continue
                    Benchmark._create_debug_file(generation_file, gen_data, samples, overwrite)
                    eval_results[model_name][task_name][split] = {
                        **metrics,
                        "total_time": gen_data.get("total_time", 0),
                        "avg_time_per_sample": gen_data.get("avg_time_per_sample", 0),
                    }
            model_total_time = Benchmark._calculate_model_total_time(eval_results[model_name])
            eval_results[model_name]["_total_time"] = model_total_time
            console.print(f"\n✓ Total time: {model_total_time:.2f}s ({model_total_time/60:.2f}min)\n", style=success_style)
        Benchmark._save_results_as_json(eval_results, output_path)
================================================
FILE: benchmarks/benchmark/checkpoint_utils.py
================================================
"""
PT format model checkpoint loading tool
Supports loading PyTorch model checkpoints in non-safetensor format
"""
import torch
import hashlib
from pathlib import Path
from typing import Dict, Optional, Tuple, List
from difflib import SequenceMatcher
from benchmark.console import console
def match_checkpoint_keys_to_model(
    checkpoint_keys: List[str],
    model_keys: List[str],
    similarity_threshold: float = 0.8
) -> Dict[str, str]:
    """
    Map checkpoint key names onto model key names.

    Resolution order per checkpoint key: exact match, strip a "model."
    prefix, add a "model." prefix, then a fuzzy match via SequenceMatcher
    (first model key with the highest ratio >= similarity_threshold).
    Keys that match nothing are simply omitted from the result.

    Args:
        checkpoint_keys: List of key names in the checkpoint
        model_keys: List of key names in the model
        similarity_threshold: Minimum SequenceMatcher ratio for fuzzy matches

    Returns:
        Mapping dictionary {checkpoint_key: model_key}
    """
    resolved: Dict[str, str] = {}
    for source_key in checkpoint_keys:
        # 1) Exact match.
        if source_key in model_keys:
            resolved[source_key] = source_key
            continue
        # 2) Strip a leading "model." prefix.
        if source_key.startswith("model."):
            stripped = source_key[len("model."):]
            if stripped in model_keys:
                resolved[source_key] = stripped
                continue
        # 3) Add a "model." prefix.
        prefixed = "model." + source_key
        if prefixed in model_keys:
            resolved[source_key] = prefixed
            continue
        # 4) Fuzzy match: keep the first model key with the best ratio
        #    at or above the threshold.
        top_key: Optional[str] = None
        top_score = 0.0
        for target_key in model_keys:
            ratio = SequenceMatcher(None, source_key, target_key).ratio()
            if ratio > top_score and ratio >= similarity_threshold:
                top_score = ratio
                top_key = target_key
        if top_key:
            resolved[source_key] = top_key
            console.print(f"Similarity match: {source_key} -> {top_key} (score: {top_score:.2f})")
    return resolved
def check_embedding_weight_sharing(
    state_dict: Dict[str, torch.Tensor],
    verbose: bool = True
) -> Tuple[bool, Optional[str], Optional[str]]:
    """
    Check whether embed_tokens and lm_head weights are shared.

    Args:
        state_dict: Model state dictionary
        verbose: Whether to print detailed information

    Returns:
        (is_shared, embed_key, lm_head_key); is_shared is False when either
        key is missing from the state dictionary.
    """
    embed_key: Optional[str] = None
    lm_head_key: Optional[str] = None
    # Scan for the two keys; a later match overwrites an earlier one,
    # mirroring a plain linear scan.
    for name in state_dict.keys():
        if "embed_tokens.weight" in name:
            embed_key = name
        elif "lm_head.weight" in name:
            lm_head_key = name
    if not embed_key or not lm_head_key:
        if verbose:
            console.print(f"Complete weight pair not found: embed_tokens={embed_key}, lm_head={lm_head_key}")
        return False, embed_key, lm_head_key
    embedding = state_dict[embed_key]
    projection = state_dict[lm_head_key]
    if verbose:
        console.print(f"embed_tokens.weight shape: {embedding.shape}")
        console.print(f"lm_head.weight shape: {projection.shape}")
    # Exact element-wise (and shape) equality counts as shared.
    shared = torch.equal(embedding, projection)
    if verbose:
        if shared:
            console.print("✓ embed_tokens and lm_head weights are identical (shared weights)")
        else:
            console.print("✗ embed_tokens and lm_head weights are different")
            # Difference statistics; NOTE(review): assumes both tensors are
            # broadcast-compatible — differing shapes would raise here, same
            # as the original behavior.
            mismatched = (embedding != projection).sum().item()
            element_count = embedding.numel()
            console.print(f" Different elements: {mismatched}/{element_count} ({mismatched/element_count*100:.2f}%)")
    return shared, embed_key, lm_head_key
def handle_weight_tying(
    state_dict: Dict[str, torch.Tensor],
    model_keys: List[str],
    new_state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Handle weight tying between embed_tokens and lm_head.

    In some models the input embedding (embed_tokens) and the output
    projection (lm_head) share one weight matrix, so a checkpoint may only
    contain one side of the pair. This fills in the missing side of
    new_state_dict by reusing the tensor from the side that is present.

    Note: the annotation of new_state_dict was fixed from Dict[str, str] to
    Dict[str, torch.Tensor] — the values are weight tensors, not strings.

    Args:
        state_dict: Original checkpoint state dictionary
        model_keys: List of key names expected by the model
        new_state_dict: Already key-mapped state dictionary (mutated in place)

    Returns:
        The updated state dictionary (same object as new_state_dict)
    """
    # Scenario 1: checkpoint has embed_tokens but lm_head was not mapped
    embed_key = next((k for k in state_dict.keys() if "embed_tokens.weight" in k), None)
    if embed_key:
        has_lm_head = any("lm_head.weight" in k for k in new_state_dict.keys())
        if not has_lm_head:
            # Try the model's plausible lm_head key names in order.
            for candidate in ("lm_head.weight", "model.lm_head.weight"):
                if candidate in model_keys:
                    new_state_dict[candidate] = state_dict[embed_key]
                    console.print(f"✓ Weight tying: using {embed_key} to initialize {candidate}")
                    break
    # Scenario 2: checkpoint has lm_head but embed_tokens was not mapped
    lm_head_key = next((k for k in state_dict.keys() if "lm_head.weight" in k), None)
    if lm_head_key:
        has_embed = any("embed_tokens.weight" in k for k in new_state_dict.keys())
        if not has_embed:
            # Try the model's plausible embed_tokens key names in order.
            for candidate in ("embed_tokens.weight", "model.embed_tokens.weight"):
                if candidate in model_keys:
                    new_state_dict[candidate] = state_dict[lm_head_key]
                    console.print(f"✓ Weight tying: using {lm_head_key} to initialize {candidate}")
                    break
    return new_state_dict
def load_weights_from_pt(
    model: torch.nn.Module,
    checkpoint_path: str,
    device: str = "cpu",
    strict: bool = False,
    check_weight_sharing: bool = True,
    handle_weight_tying_flag: bool = True
) -> Tuple[List[str], List[str]]:
    """
    Load a PT-format checkpoint into a model.

    Pipeline: load the file, unwrap a nested state dict if present, map
    checkpoint key names onto model key names, optionally fill in tied
    weights, then call model.load_state_dict.

    Args:
        model: Target model
        checkpoint_path: Checkpoint file path
        device: Device used as map_location when loading
        strict: Whether load_state_dict requires all keys to match
        check_weight_sharing: Whether to report embed/lm_head weight sharing
        handle_weight_tying_flag: Whether to fill in missing tied weights

    Returns:
        (missing_keys, unexpected_keys) as reported by load_state_dict
    """
    console.print(f"Loading checkpoint: {checkpoint_path}")
    # 1. Load checkpoint.
    # NOTE(review): torch.load defaults to weights_only=True on newer torch
    # releases, which can reject fully-pickled checkpoints — confirm against
    # the torch version in use.
    try:
        state_dict = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        console.print(f"Failed to load checkpoint: {e}")
        raise
    # 2. Unwrap common nesting conventions (first match wins, mirroring the
    #    original if/elif order).
    for wrapper in ("model_state_dict", "state_dict"):
        if wrapper in state_dict:
            console.print(f"Detected '{wrapper}' key, extracting nested state dictionary")
            state_dict = state_dict[wrapper]
            break
    checkpoint_keys = list(state_dict.keys())
    model_keys = list(model.state_dict().keys())
    console.print(f"Checkpoint key count: {len(checkpoint_keys)}")
    console.print(f"Model key count: {len(model_keys)}")
    if check_weight_sharing:
        check_embedding_weight_sharing(state_dict, verbose=True)
    console.print("Starting to match checkpoint key names to model key names...")
    key_mapping = match_checkpoint_keys_to_model(checkpoint_keys, model_keys)
    console.print(f"Successfully matched: {len(key_mapping)}/{len(checkpoint_keys)} keys")
    # Re-key the tensors; collect keys that could not be mapped.
    remapped: Dict[str, torch.Tensor] = {}
    unmatched: List[str] = []
    for ckpt_key in checkpoint_keys:
        if ckpt_key in key_mapping:
            remapped[key_mapping[ckpt_key]] = state_dict[ckpt_key]
        else:
            unmatched.append(ckpt_key)
    if unmatched:
        console.print(f"Skipped {len(unmatched)} unmatched keys")
        if len(unmatched) <= 10:
            console.print(f"Skipped keys: {unmatched}")
    if handle_weight_tying_flag:
        remapped = handle_weight_tying(state_dict, model_keys, remapped)
    console.print("Loading state dictionary into model...")
    missing_keys, unexpected_keys = model.load_state_dict(remapped, strict=strict)
    # Report both key categories with identical formatting.
    for label, keys in (("missing", missing_keys), ("unexpected", unexpected_keys)):
        if keys:
            suffix = '...' if len(keys) > 10 else ''
            console.print(f"{label.capitalize()} keys ({len(keys)}): {keys[:10]}{suffix}")
        else:
            console.print(f"✓ No {label} keys")
    console.print("✓ Checkpoint loading completed")
    return missing_keys, unexpected_keys
def build_model_from_pt(
    config_path: str,
    checkpoint_path: str,
    device: str = "cuda",
    torch_dtype: Optional[torch.dtype] = None,
    trust_remote_code: bool = True
) -> torch.nn.Module:
    """
    Instantiate a model from a config and load a PT checkpoint into it.

    This is the unified entry point shared by HfTransformersGenerator and
    RayHfTransformersGenerator.

    Args:
        config_path: Model configuration path
        checkpoint_path: PT checkpoint path
        device: Target device
        torch_dtype: Optional dtype to cast the model to
        trust_remote_code: Whether to trust remote code

    Returns:
        Model with checkpoint weights loaded
    """
    from transformers import AutoConfig, AutoModelForCausalLM

    # Build an uninitialized model skeleton from the configuration.
    model_config = AutoConfig.from_pretrained(
        config_path,
        trust_remote_code=trust_remote_code
    )
    model = AutoModelForCausalLM.from_config(
        model_config,
        trust_remote_code=trust_remote_code
    )
    # Apply dtype/device before loading so the weights land where expected.
    if torch_dtype is not None:
        model = model.to(torch_dtype)
    if device != 'cpu':
        model = model.to(device)
    load_weights_from_pt(
        model=model,
        checkpoint_path=checkpoint_path,
        device=device if device != 'cpu' else 'cpu',
        strict=False,
        check_weight_sharing=True,
        handle_weight_tying_flag=True
    )
    return model
def build_model_from_hf(
    model_name_or_path: str,
    device: str = "cuda",
    torch_dtype: Optional[torch.dtype] = None,
    trust_remote_code: bool = True,
    use_device_map: bool = True
) -> torch.nn.Module:
    """
    Load a pretrained model from HuggingFace.

    This is the unified entry point shared by HfTransformersGenerator and
    RayHfTransformersGenerator.

    Args:
        model_name_or_path: Model name or path
        device: Target device
        torch_dtype: Data type
        trust_remote_code: Whether to trust remote code
        use_device_map: Whether to use device_map="auto" for multi-GPU

    Returns:
        Loaded model
    """
    from transformers import AutoModelForCausalLM

    # Automatic placement only makes sense for CUDA targets.
    auto_placement = use_device_map and device != "cpu" and "cuda" in device
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch_dtype,
        device_map="auto" if auto_placement else None,
        trust_remote_code=trust_remote_code
    )
    if not auto_placement:
        # Manual placement when device_map is not used.
        model = model.to(device)
    return model
def export_pt_to_safetensor(
    config_path: str,
    checkpoint_path: str,
    output_dir: Optional[str] = None,
    trust_remote_code: bool = True,
    use_cache: bool = True
) -> str:
    """
    Convert a PT checkpoint to HuggingFace (safetensors) format for vLLM.

    The default output location is derived from an md5 hash of the
    (config_path, checkpoint_path) pair, so repeated conversions of the same
    inputs reuse one directory. With use_cache=True an existing conversion
    (config.json plus weight files) short-circuits the whole process.

    Args:
        config_path: Model configuration path (HuggingFace model path or local config)
        checkpoint_path: PT checkpoint path
        output_dir: Output directory for converted model (optional, /tmp-based default)
        trust_remote_code: Whether to trust remote code
        use_cache: Whether to use cached conversion (skip if already converted)

    Returns:
        Path to converted HuggingFace format model
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

    # Cache key: md5 over both paths (not security sensitive).
    digest = hashlib.md5(f"{config_path}_{checkpoint_path}".encode('utf-8')).hexdigest()[:16]
    if output_dir is None:
        output_dir = f"/tmp/hf_checkpoint_{digest}"
    target_path = Path(output_dir) / "converted_model"

    # Reuse a previous conversion when the directory looks complete.
    if use_cache and target_path.exists():
        config_present = (target_path / "config.json").exists()
        weights_present = (
            (target_path / "model.safetensors").exists() or
            (target_path / "pytorch_model.bin").exists() or
            any(target_path.glob("*.safetensors")) or
            any(target_path.glob("pytorch_model*.bin"))
        )
        if config_present and weights_present:
            console.print("✓ Found converted model, skipping conversion")
            console.print(f" Converted model path: {target_path}")
            return str(target_path)

    target_path.mkdir(parents=True, exist_ok=True)
    console.print(f" Output directory: {target_path}")
    try:
        # 1. Load configuration
        console.print(" [1/4] Loading model configuration...")
        model_config = AutoConfig.from_pretrained(
            config_path,
            trust_remote_code=trust_remote_code
        )
        # 2. Create model from config
        console.print(" [2/4] Initializing model...")
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=trust_remote_code
        )
        # 3. Load checkpoint
        console.print(" [3/4] Loading PT checkpoint...")
        load_weights_from_pt(
            model=model,
            checkpoint_path=checkpoint_path,
            device='cpu',
            strict=False,
            check_weight_sharing=True,
            handle_weight_tying_flag=True
        )
        # 4. Save as HuggingFace format (weights + tokenizer)
        console.print(" [4/4] Saving as HuggingFace format...")
        model.save_pretrained(target_path, safe_serialization=True)
        tokenizer = AutoTokenizer.from_pretrained(
            config_path,
            trust_remote_code=trust_remote_code
        )
        tokenizer.save_pretrained(target_path)
        console.print(f"✓ Model conversion completed: {target_path}")
        return str(target_path)
    except Exception as e:
        console.print(f"✗ Conversion failed: {e}")
        # Remove the partial output so a broken cache is never reused.
        import shutil
        if target_path.exists():
            shutil.rmtree(target_path)
        raise
================================================
FILE: benchmarks/benchmark/console.py
================================================
from rich.console import Console
from pyfiglet import Figlet

# Shared Rich console instance used across the benchmark package.
console = Console()

# Common message styles.
err_style = "bold red"
warning_style = "bold yellow"
success_style = "green"
dim_style = "dim"

# benchmark dataset
f = Figlet(font='digital')


# PEP 8 (E731): a named def instead of a lambda assignment.
def head_print(x):
    """Render text as an ASCII-art banner using the 'digital' figlet font."""
    return f.renderText(x)


head_style = "bold white on blue"
subhead_style = "bold black on bright_blue"
row_style = "black on bright_white"

# Generator styles
head_style_2 = "bold white on magenta"
subhead_style_2 = "white"
================================================
FILE: benchmarks/benchmark/generation_runner.py
================================================
"""
Generation Runner
Responsible for:
1. Loading test data via data loader
2. Calling Generator to produce model outputs
3. Saving generation results to JSON files
Note: Does NOT compute evaluation metrics (handled by task-specific evaluators)
"""
import json
import os
import time
from typing import Dict, List, Optional, Any
from pathlib import Path
from benchmark.console import *
from benchmark.base_generator import Generator
from benchmark.tasks.v1_0.base_loader import BaseLoader
class GenerationRunner:
    """
    Generation task runner

    Orchestrates the generation phase of evaluation:
    - Loads test data via data loader
    - Calls generator to produce model outputs
    - Saves generation results to disk

    Evaluation metrics are computed separately by task-specific evaluators.
    """
    def __init__(
        self,
        data_loader: BaseLoader,
        overwrite: bool = False
    ):
        """
        Args:
            data_loader: Data loader (any object with load_data method)
            overwrite: Whether to overwrite existing results
        """
        self.data_loader = data_loader
        self.overwrite = overwrite
        # Exposed so callers can inspect which benchmark version produced the results.
        self.benchmark_version = data_loader.benchmark_version

    def __call__(
        self,
        task_name: str,
        split: str,
        results_save_dir: str,
        generator: Generator,
        **kwargs
    ) -> None:
        """
        Execute generation pipeline

        This method is responsible for generation and saving only,
        NOT for computing evaluation metrics.

        Args:
            task_name: Task name
            split: Dataset split
            results_save_dir: Results save directory
            generator: Generator instance
            **kwargs: Generation parameters; a 'sample_size' entry is consumed
                here and forwarded to the data loader, not the generator

        Returns:
            None
        """
        model_name = str(generator)
        results_dir = os.path.join(
            results_save_dir,
            model_name,
            task_name
        )
        os.makedirs(results_dir, exist_ok=True)
        generation_file = os.path.join(results_dir, f"{split}_generated.json")
        # Check if generation results already exist
        if os.path.exists(generation_file) and not self.overwrite:
            console.print(f"Generation results already exist, skipping: {generation_file}")
            console.print("To regenerate, please set overwrite=True")
            return None
        start_time = time.time()
        # Extract sample_size parameter (don't pass to generator)
        sample_size_param = kwargs.pop('sample_size', None)
        # 1. Load data
        test_data = self.data_loader.load_data(task_name=task_name, split=split, sample_size=sample_size_param)
        # 2. Extract prompts and references (renamed loop vars to avoid
        #    shadowing the builtin `id`)
        prompts = {sample_id: data["prompt"] for sample_id, data in test_data.items()}
        references = {sample_id: data["ground_truth"] for sample_id, data in test_data.items()}
        # 3. Generate text (unified entry point)
        # All tasks now go through the unified generate() method
        # For classification tasks, target_tokens is already in kwargs from generation_config
        generations, logprobs = generator.generate(prompts, **kwargs)
        end_time = time.time()
        total_time = end_time - start_time
        num_samples = len(test_data)
        avg_time_per_sample = total_time / num_samples if num_samples > 0 else 0
        console.print(f"Total time: {total_time:.2f}s, Average per sample: {avg_time_per_sample:.4f}s")
        # 4. Collect hardware info and MFU statistics (for MFU calculation)
        console.print("[MFU DEBUG] Starting MFU data collection...")
        hardware_info, mfu_stats = self._collect_mfu_data(generator)
        num_params_value = getattr(generator, 'num_params', None)
        console.print(f"[MFU DEBUG] num_params value: {num_params_value}")
        # 5. Save generation results
        self.save_generations(
            model_name=model_name,
            task_name=task_name,
            split=split,
            generations=generations,
            references=references,
            logprobs=logprobs,
            test_data=test_data,
            output_path=generation_file,
            total_time=total_time,
            avg_time_per_sample=avg_time_per_sample,
            hardware_info=hardware_info,
            mfu_stats=mfu_stats,
            num_params=num_params_value,
        )
        console.print(f"Generation results saved to: {generation_file}")
        return None

    @staticmethod
    def _collect_mfu_data(generator: Generator):
        """
        Best-effort collection of hardware info and per-sample MFU statistics
        from the generator. Never raises; returns (hardware_info, mfu_stats),
        either of which may be None when the generator does not expose them.
        """
        hardware_info = None
        mfu_stats = None
        try:
            # Check if generator has get_hardware_info method
            if not hasattr(generator, 'get_hardware_info'):
                console.print("[MFU ERROR] generator does NOT have get_hardware_info() method!")
                console.print(f"[MFU ERROR] Generator type: {type(generator)}")
                console.print(f"[MFU ERROR] Generator class: {generator.__class__.__name__}")
            else:
                hardware_info = generator.get_hardware_info()
                if hardware_info:
                    console.print(f"[MFU DEBUG] GPU Model: {hardware_info.get('gpu_model')}")
                    console.print(f"[MFU DEBUG] GPU Count: {hardware_info.get('gpu_count')}")
                    console.print(f"[MFU DEBUG] GPU TFLOPs: {hardware_info.get('gpu_tflops')}")
                else:
                    console.print("[MFU WARNING] hardware_info is None!")
            # Check if generator has mfu_stats attribute
            if not hasattr(generator, 'mfu_stats'):
                console.print("[MFU WARNING] generator does NOT have 'mfu_stats' attribute!")
            else:
                mfu_stats = getattr(generator, 'mfu_stats', None)
                if mfu_stats:
                    console.print(f"[MFU DEBUG] mfu_stats sample count: {len(mfu_stats)}")
                    if len(mfu_stats) > 0:
                        first_key = list(mfu_stats.keys())[0]
                        first_stats = mfu_stats[first_key]
                        console.print(f"[MFU DEBUG] First sample: {first_key}")
                        console.print(f"[MFU DEBUG] input_tokens: {first_stats.get('input_tokens', 'MISSING')}")
                        console.print(f"[MFU DEBUG] output_tokens: {first_stats.get('output_tokens', 'MISSING')}")
                        console.print(f"[MFU DEBUG] times: {first_stats.get('times', 'MISSING')}")
                else:
                    console.print("[MFU WARNING] mfu_stats is None!")
        except Exception as e:
            console.print(f"Warning: Failed to collect hardware info or MFU stats: {e}", style=warning_style)
        return hardware_info, mfu_stats

    @staticmethod
    def save_generations(
        model_name: str,
        task_name: str,
        split: str,
        generations: Dict[str, List[str]],
        references: Dict[str, str],
        logprobs: Dict[str, List[float]],
        test_data: Dict[str, Dict[str, Any]],
        output_path: str,
        total_time: float,
        avg_time_per_sample: float,
        hardware_info: Optional[Dict[str, Any]] = None,
        mfu_stats: Optional[Dict[str, Dict[str, List[int]]]] = None,
        num_params: Optional[float] = None,
    ):
        """
        Save generation results (excluding evaluation metrics)

        Result format:
        {
            "model_name": "...",
            "task_name": "...",
            "split": "...",
            "total_time": "...",
            "avg_time_per_sample": "...",
            "samples": {
                "<sample_id>": {
                    "prompt": "...",
                    "generations": ["...", "..."],
                    "ground_truth": "...",
                    "metadata": {...}  # Contains metadata from original data
                },
                ...
            }
        }
        """
        # Check if this is a classification task (label_pred)
        is_classification_task = task_name == "label_pred"
        samples: Dict[str, Any] = {}
        for sample_id, gens in generations.items():
            sample_data = {
                "prompt": test_data.get(sample_id, {}).get("prompt", ""),
                "generations": gens,
                "ground_truth": references.get(sample_id, ""),
            }
            if sample_id in logprobs and logprobs[sample_id]:
                sample_data["logprobs"] = logprobs[sample_id]
            # Add MFU statistics for this sample (for MFU calculation)
            if mfu_stats and sample_id in mfu_stats:
                sample_data["input_tokens"] = mfu_stats[sample_id].get("input_tokens", [])
                sample_data["output_tokens"] = mfu_stats[sample_id].get("output_tokens", [])
                sample_data["times"] = mfu_stats[sample_id].get("times", [])
            if is_classification_task and sample_id in test_data:
                metadata = test_data[sample_id].get("metadata", {})
                if "uid" in metadata:
                    sample_data["user_id"] = metadata["uid"]
            if sample_id in test_data and "metadata" in test_data[sample_id]:
                sample_data["metadata"] = test_data[sample_id]["metadata"]
            samples[sample_id] = sample_data
        data = {
            "model_name": model_name,
            "task_name": task_name,
            "split": split,
            "total_time": total_time,
            "avg_time_per_sample": avg_time_per_sample,
            "samples": samples,
        }
        # Add hardware info and token statistics (for MFU calculation)
        if hardware_info:
            data["hardware_info"] = hardware_info
        else:
            console.print("[MFU DEBUG] ❌ Skipping hardware_info (None or empty)")
        if num_params:
            data["num_params"] = num_params
        else:
            console.print("[MFU DEBUG] ❌ Skipping num_params (None or 0)")
        # Save mfu_stats_aggregate for multi-stage MFU calculation
        # Compute aggregate statistics from per-sample mfu_stats
        if mfu_stats:
            # FIX: num_stages was previously taken from the FIRST sample only,
            # silently dropping later stages of any sample with more stages.
            # Use the maximum stage count across all samples instead.
            num_stages = max(
                (len(s.get("input_tokens", [])) for s in mfu_stats.values()),
                default=0,
            )
            console.print(f"[MFU DEBUG] Determined num_stages: {num_stages}")
            # Structure: dict with parallel lists, one entry per stage
            data["mfu_stats_aggregate"] = {
                "total_input_tokens": [],
                "total_output_tokens": [],
                "total_time": []
            }
            for stage_idx in range(num_stages):
                total_input_tokens = 0
                total_output_tokens = 0
                stage_times = []
                # Aggregate token stats across all samples for this stage
                for sample_stats in mfu_stats.values():
                    input_tokens_list = sample_stats.get("input_tokens", [])
                    output_tokens_list = sample_stats.get("output_tokens", [])
                    if stage_idx < len(input_tokens_list):
                        total_input_tokens += input_tokens_list[stage_idx]
                    if stage_idx < len(output_tokens_list):
                        total_output_tokens += output_tokens_list[stage_idx]
                    # Stage time = max across samples: Ray workers run in
                    # parallel, so a stage lasts as long as its slowest worker.
                    times_list = sample_stats.get("times", [])
                    if stage_idx < len(times_list):
                        stage_times.append(times_list[stage_idx])
                stage_time = max(stage_times) if stage_times else 0.0
                data["mfu_stats_aggregate"]["total_input_tokens"].append(total_input_tokens)
                data["mfu_stats_aggregate"]["total_output_tokens"].append(total_output_tokens)
                data["mfu_stats_aggregate"]["total_time"].append(stage_time)
        else:
            console.print("[MFU DEBUG] ❌ Skipping mfu_stats processing (None or empty)")
        # FIX: guard against output_path with no directory component
        # (os.makedirs("") raises FileNotFoundError).
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
================================================
FILE: benchmarks/benchmark/gpu_utils.py
================================================
"""
GPU hardware detection and FLOPS calculation utilities for MFU computation.
"""
from typing import Dict, Any, Optional
from benchmark.console import console
# GPU theoretical peak FLOPS (TFLOPS) for BF16/FP16
# Source: Official vendor specifications
GPU_TFLOPS_MAP = {
# NVIDIA A100 series
"A100-SXM4-40GB": 312.0,
"A100-SXM4-80GB": 312.0,
"A100-PCIE-40GB": 312.0,
"A100-PCIE-80GB": 312.0,
# NVIDIA A800 series (China-specific A100 variant)
"A800-SXM4-80GB": 312.0,
"A800": 312.0,
# NVIDIA H100 series
"H100-SXM5-80GB": 989.0,
"H100-PCIE-80GB": 756.0,
"H100": 989.0,
# NVIDIA V100 series
"V100-SXM2-16GB": 125.0,
"V100-SXM2-32GB": 125.0,
"V100-PCIE-16GB": 112.0,
"V100-PCIE-32GB": 112.0,
# NVIDIA A40
"A40": 149.7,
# NVIDIA A30
"A30": 165.0,
# NVIDIA A10
"A10": 125.0,
# NVIDIA RTX series
"RTX 4090": 82.6,
"RTX 4080": 48.7,
"RTX 3090": 35.6,
"RTX 3080": 29.8,
}
def _normalize_gpu_name(gpu_name: str) -> str:
"""
Normalize GPU name for lookup in TFLOPS map.
Args:
gpu_name: Raw GPU name from torch.cuda
Returns:
Normalized GPU name
"""
gpu_name = gpu_name.strip()
# Try exact match first
if gpu_name in GPU_TFLOPS_MAP:
return gpu_name
# Try fuzzy matching
gpu_name_upper = gpu_name.upper()
# Match A100 variants
if "A100" in gpu_name_upper:
if "80GB" in gpu_name_upper or "80G" in gpu_name_upper:
return "A100-SXM4-80GB"
else:
return "A100-SXM4-40GB"
# Match A800
if "A800" in gpu_name_upper:
return "A800"
# Match H100 variants
if "H100" in gpu_name_upper:
if "PCIE" in gpu_name_upper or "PCIe" in gpu_name_upper:
return "H100-PCIE-80GB"
else:
return "H100-SXM5-80GB"
# Match V100 variants
if "V100" in gpu_name_upper:
if "32GB" in gpu_name_upper or "32G" in gpu_name_upper:
return "V100-SXM2-32GB"
else:
return "V100-SXM2-16GB"
# Match other GPUs
for known_gpu in GPU_TFLOPS_MAP.keys():
if known_gpu.upper() in gpu_name_upper:
return known_gpu
return gpu_name
def get_gpu_tflops(gpu_name: str) -> Optional[float]:
"""
Get theoretical peak TFLOPS for a given GPU model.
Args:
gpu_name: GPU model name
Returns:
TFLOPS value for BF16/FP16, or None if unknown
"""
normalized_name = _normalize_gpu_name(gpu_name)
return GPU_TFLOPS_MAP.get(normalized_name)
def get_gpu_info() -> Dict[str, Any]:
    """
    Detect GPU hardware information using PyTorch.

    Returns:
        Dictionary containing:
        - gpu_available: bool, whether GPU is available
        - gpu_count: int, number of GPUs
        - gpu_model: str, GPU model name
        - gpu_memory_total_gb: float, total GPU memory in GB
        - gpu_tflops: float or None, theoretical peak TFLOPS for BF16/FP16
    """
    # Returned whenever no usable GPU can be found.
    unavailable = {
        "gpu_available": False,
        "gpu_count": 0,
        "gpu_model": "unknown",
        "gpu_memory_total_gb": 0.0,
        "gpu_tflops": None,
    }
    try:
        import torch
    except ImportError:
        console.print("PyTorch not available, cannot detect GPU info")
        return unavailable
    if not torch.cuda.is_available():
        console.print("CUDA not available")
        return unavailable
    device_count = torch.cuda.device_count()
    # Get properties of the first GPU (assume homogeneous cluster)
    props = torch.cuda.get_device_properties(0)
    model_name = props.name
    memory_gb = props.total_memory / (1024 ** 3)  # Convert bytes to GB
    tflops = get_gpu_tflops(model_name)
    if tflops is None:
        console.print(
            f"Unknown GPU model '{model_name}', cannot determine TFLOPS. "
            f"Please add it to GPU_TFLOPS_MAP in gpu_utils.py"
        )
    info = {
        "gpu_available": True,
        "gpu_count": device_count,
        "gpu_model": model_name,
        "gpu_memory_total_gb": round(memory_gb, 2),
        "gpu_tflops": tflops,
    }
    console.print(f"Detected GPU: {model_name} x {device_count}, {tflops} TFLOPS (BF16/FP16)")
    return info
================================================
FILE: benchmarks/benchmark/tasks/__init__.py
================================================
"""
Tasks definition for Benchmark
"""
from .tasks import (
BenchmarkTable,
check_benchmark_version,
check_task_types,
check_splits,
LATEST_BENCHMARK_VERSION,
)
__all__ = [
"BenchmarkTable",
"check_benchmark_version",
"check_task_types",
"check_splits",
"LATEST_BENCHMARK_VERSION",
]
================================================
FILE: benchmarks/benchmark/tasks/tasks.py
================================================
"""
Task table and utility functions for Benchmark
"""
from typing import List, Optional, Tuple
from benchmark.tasks.v1_0.registry import TaskTable as TaskTable_v1_0
# Default benchmark version used wherever a version argument is optional.
LATEST_BENCHMARK_VERSION = "v1.0"
# Registry: benchmark version string -> task table for that version.
BenchmarkTable = {
    "v1.0": TaskTable_v1_0,
}
def get_available_benchmark_versions() -> List[str]:
    """Return every registered benchmark version, sorted ascending."""
    return sorted(BenchmarkTable)
def get_available_task_types(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:
    """Return all task types registered under the given version, sorted."""
    return sorted(BenchmarkTable[benchmark_version])
def get_available_domains(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:
    """Return every domain appearing in any task of the given version, sorted."""
    unique_domains = {
        domain
        for task_table in BenchmarkTable[benchmark_version].values()
        for domain in task_table.keys()
    }
    return sorted(unique_domains)
def get_available_languages(benchmark_version: str = LATEST_BENCHMARK_VERSION) -> List[str]:
    """Return every language appearing in any task of the given version, sorted."""
    unique_languages = {
        lang
        for task_table in BenchmarkTable[benchmark_version].values()
        for task in task_table.values()
        for lang in task.keys()
    }
    return sorted(unique_languages)
def check_benchmark_version(benchmark_version: Optional[str]) -> str:
    """
    Validate a benchmark version, defaulting to the latest.

    Args:
        benchmark_version: Version to validate; None selects the latest version.

    Returns:
        str: A valid benchmark version.

    Raises:
        ValueError: If the version is not registered.
    """
    if benchmark_version is None:
        return LATEST_BENCHMARK_VERSION
    known_versions = get_available_benchmark_versions()
    if benchmark_version not in known_versions:
        raise ValueError(
            f"Invalid benchmark version: {benchmark_version}. Available versions: {', '.join(known_versions)}"
        )
    return benchmark_version
def check_task_types(
    task_types: Optional[List[str]],
    benchmark_version: str = LATEST_BENCHMARK_VERSION,
) -> List[str]:
    """
    Validate task types for a benchmark version.

    Input is normalized to lowercase BEFORE de-duplication, so mixed-case
    duplicates such as ["Rec", "rec"] collapse to a single entry.

    Args:
        task_types: Task types to validate (str or list); None selects all
            available task types.
        benchmark_version: Benchmark version.

    Returns:
        List[str]: Sorted, lowercase, de-duplicated task types.

    Raises:
        ValueError: If any task type is not available for the version.
    """
    available_task_types = get_available_task_types(benchmark_version)
    if task_types is None:
        return available_task_types
    if isinstance(task_types, str):
        task_types = [task_types]
    # FIX: lowercase first, THEN de-duplicate. Previously de-duplication ran
    # before lowercasing, so ["Rec", "rec"] produced duplicate entries.
    task_types = sorted({task_type.lower() for task_type in task_types})
    for task_type in task_types:
        if task_type not in available_task_types:
            raise ValueError(
                f"{benchmark_version} | Invalid task type: {task_type}. Available task types: {', '.join(available_task_types)}"
            )
    return task_types
def check_splits(
    splits: Optional[List[str]],
    benchmark_version: str = LATEST_BENCHMARK_VERSION,
) -> List[str]:
    """
    Validate dataset splits (only the "test" split is supported).

    Input is normalized to lowercase BEFORE de-duplication, so mixed-case
    duplicates such as ["Test", "test"] collapse to a single entry.

    Args:
        splits: Splits to validate (str or list); None selects all available
            splits.
        benchmark_version: Benchmark version (used in error messages only).

    Returns:
        List[str]: Sorted, lowercase, de-duplicated splits.

    Raises:
        ValueError: If any split is not available.
    """
    # Only allow test split
    available_splits = ["test"]
    if splits is None:
        return available_splits
    if isinstance(splits, str):
        splits = [splits]
    # FIX: lowercase first, THEN de-duplicate. Previously de-duplication ran
    # before lowercasing, so ["Test", "test"] produced duplicate entries.
    splits = sorted({split.lower() for split in splits})
    for split in splits:
        if split not in available_splits:
            raise ValueError(
                f"{benchmark_version} | Invalid split: {split}. Available splits: {', '.join(available_splits)}"
            )
    return splits
================================================
FILE: benchmarks/benchmark/tasks/v1_0/__init__.py
================================================
"""
v1.0 Version Task Definitions
"""
from .registry import TaskTable
__all__ = ["TaskTable"]
================================================
FILE: benchmarks/benchmark/tasks/v1_0/base_evaluator.py
================================================
"""
Base Evaluator for all task evaluators
Provides common interface for evaluation logic.
"""
import json
import os
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional, List
from benchmark.console import console, success_style
class BaseEval(ABC):
    """Shared base class for task evaluators.

    Stores the loaded samples plus evaluation context and implements a
    cache-aware ``evaluate()`` entry point. Subclasses override the
    ``required_metrics`` property and ``_compute_metrics_from_scratch()``.
    """

    def __init__(
        self,
        samples: Dict[str, Dict[str, Any]],
        task_name: Optional[str] = None,
        predictions_dir: Optional[str] = None,
        debug: bool = False,
        task_config: Optional[Dict[str, Any]] = None,
        data_dir: Optional[str] = None,
        overwrite: bool = False,
        cached_metrics: Optional[Dict[str, Any]] = None
    ):
        """Initialize the evaluator.

        Args:
            samples: Samples from test_generated.json, mapping sample_id to a
                dict with "prompt", "generations", "ground_truth", "metadata".
            task_name: Task name (e.g. "math_500").
            predictions_dir: Directory where debug files are written (optional).
            debug: Whether to save debug information.
            task_config: Task configuration dictionary (optional).
            data_dir: Data directory path (optional).
            overwrite: Recompute metrics from scratch even when cached.
            cached_metrics: Previously computed overall metrics, if any.
        """
        self.samples = samples
        self.task_name = task_name
        self.predictions_dir = predictions_dir
        self.debug = debug
        self.task_config = task_config or {}
        self.data_dir = data_dir
        self.overwrite = overwrite
        self.cached_metrics = cached_metrics or {}

    def evaluate(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
        """Evaluate the samples, honoring cached overall metrics.

        Flow: ``overwrite=True`` forces a full recompute; otherwise cached
        overall metrics are returned when they cover all required keys; else
        metrics are computed fresh.

        Returns:
            Tuple of (metrics, per_sample_metrics). per_sample_metrics is
            empty when cached metrics are served.
        """
        if self.overwrite:
            console.print("[cyan]Overwrite=True, recomputing all metrics from scratch...[/cyan]")
            return self._compute_metrics_from_scratch()
        if self._has_all_required_metrics():
            console.print("[cyan]Using existing overall metrics from eval_results...[/cyan]")
            # Per-sample metrics are not needed when serving from the cache.
            return self.cached_metrics, {}
        console.print("[cyan]Computing metrics from scratch...[/cyan]")
        return self._compute_metrics_from_scratch()

    def _all_samples_have_keys(self, required_keys: List[str]) -> bool:
        """Return True when every sample contains every key in required_keys."""
        return all(
            key in sample
            for sample in self.samples.values()
            for key in required_keys
        )

    @property
    def required_metrics(self) -> Optional[List[str]]:
        """Overall metric keys the cache must contain (None disables caching)."""
        return None

    def _has_all_required_metrics(self) -> bool:
        """Check whether cached_metrics covers every required metric key."""
        required = self.required_metrics
        if required is None:
            return False
        return all(key in self.cached_metrics for key in required)

    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
        """Compute all metrics from scratch; subclasses must override this."""
        raise NotImplementedError("Subclasses must implement _compute_metrics_from_scratch()")

    def _save_debug_json(
        self,
        debug_info: Dict[str, Any],
        filename: str = "debug.json"
    ) -> Optional[str]:
        """Dump debug_info as JSON under predictions_dir.

        Returns the written file path, or None when predictions_dir is unset.
        """
        if not self.predictions_dir:
            return None
        target = os.path.join(self.predictions_dir, filename)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(debug_info, f, indent=2, ensure_ascii=False)
        console.print(f"✓ Debug information saved to: {target}", style=success_style)
        return target
================================================
FILE: benchmarks/benchmark/tasks/v1_0/base_loader.py
================================================
"""
Base Loader for all task data loaders
Provides common functionality for data loading, sampling, and file path resolution.
"""
import os
import json
import pandas as pd
from typing import Dict, Any, Optional
from abc import ABC
from benchmark.console import *
class BaseLoader(ABC):
    """Base class for all task data loaders.

    Handles locating the task's parquet files, optional down-sampling with an
    on-disk cache, and converting messages-based rows into model-ready prompts
    via the tokenizer's chat template.
    """

    def __init__(
        self,
        task_config: Dict[str, Any],
        data_dir: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        enable_thinking: Optional[bool] = None,
    ):
        """Initialize base loader.

        Args:
            task_config: Task configuration dictionary (reads "name", "size",
                "sample_size", "prompt_config").
            data_dir: Root data directory; defaults to "./data" when omitted.
            tokenizer: Tokenizer exposing apply_chat_template (required).
            enable_thinking: CLI override for the config's enable_thinking flag.

        Raises:
            ValueError: If no tokenizer is provided.
        """
        self.task_config = task_config
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.enable_thinking = enable_thinking
        self.task_name = task_config.get("name", "unknown")
        # Messages-based data requires a tokenizer to render chat templates.
        if self.tokenizer is None:
            raise ValueError(
                f"{self.task_name} requires tokenizer for messages-based format. "
                f"Please provide model_path when initializing Benchmark.\n"
                f"Example: Benchmark(task_types=['{self.task_name}'], model_path='your-model-path')"
            )

    def load_data(self, split: str = "test", sample_size: Optional[Any] = None) -> Dict[str, Dict[str, Any]]:
        """
        Load data for the task in messages-based format.

        Args:
            split: Dataset split (default "test")
            sample_size: Override sample size (can be int, "full", or None to use task config)

        Returns:
            Dictionary mapping sample_id to sample data:
            {
                sample_id: {
                    "prompt": "formatted prompt from apply_chat_template",
                    "ground_truth": "answer",
                    "metadata": {
                        "row_index": idx,
                        ...
                    }
                }
            }

        Raises:
            FileNotFoundError: If the data file does not exist.
            ValueError: If the data lacks the 'messages' or 'metadata' column.
        """
        # Determine effective sample size (explicit argument wins over config).
        if sample_size is not None:
            if sample_size == "full":
                effective_sample_size = self.task_config.get("size")
            else:
                effective_sample_size = int(sample_size)
        else:
            effective_sample_size = self.task_config.get("sample_size")
        full_size = self.task_config.get("size")

        # Try to load a previously cached sampled dataframe.
        df = None
        if effective_sample_size is not None and full_size is not None and effective_sample_size < full_size:
            df = self._load_sample_dataframe(split, effective_sample_size)

        # If no cache, load the original data and sample if needed.
        if df is None:
            df = self._load_dataframe(split)
            if effective_sample_size is not None and effective_sample_size < len(df):
                df = self._sample_data(df, effective_sample_size)
                # Persist the sampled subset so later runs can reuse it.
                # Guarding inside the sampling branch avoids comparing None
                # against full_size (TypeError) and never caches an unsampled df.
                if full_size is not None and effective_sample_size < full_size:
                    self._save_sample_data(df, split, effective_sample_size)

        if 'messages' not in df.columns:
            raise ValueError(
                f"{self.task_name} requires 'messages' column in data file. "
                f"Found columns: {list(df.columns)}\n"
                f"Please ensure your data is in messages-based format."
            )
        if 'metadata' not in df.columns:
            raise ValueError(
                f"{self.task_name} requires 'metadata' column in data file. "
                f"Found columns: {list(df.columns)}\n"
                f"Please ensure your data is in messages-based format."
            )

        console.print(f"[green]Processing {self.task_name} data in messages-based format[/green]")
        return self._process_dataframe(df)

    @staticmethod
    def _is_empty_value(value) -> bool:
        """Check if a value is None, NaN, an empty/blank string, or an empty sequence."""
        if value is None:
            return True
        if isinstance(value, float):
            try:
                return pd.isna(value)
            except (ValueError, TypeError):
                return False
        if isinstance(value, str):
            return len(value.strip()) == 0
        try:
            if hasattr(value, '__len__'):
                return len(value) == 0
        except (ValueError, TypeError):
            # len()/truthiness on exotic containers may raise; treat as non-empty.
            pass
        return False

    @staticmethod
    def _convert_messages_format(messages: list) -> list:
        """
        Convert message format.

        {"role": "user", "content": [{"type": "text", "text": "..."}]}
        ->
        {"role": "user", "content": "..."}

        Non-text content items are dropped; messages whose content is not a
        list are passed through unchanged.
        """
        converted = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                # Concatenate the text of every {"type": "text"} item.
                text_parts = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        text_parts.append(item.get("text", ""))
                converted.append({
                    "role": msg.get("role"),
                    "content": "".join(text_parts)
                })
            else:
                # Already in old format
                converted.append(msg)
        return converted

    def _load_custom_chat_template(self):
        """Install a custom chat template on the tokenizer when configured.

        Reads prompt_config.custom_chat_template (a filename relative to this
        module's directory) and assigns its contents to tokenizer.chat_template.
        No-op when there is no tokenizer or no custom template configured.

        Raises:
            FileNotFoundError: If the configured template file does not exist.
        """
        if not self.tokenizer:
            return
        prompt_config = self.task_config.get("prompt_config", {})
        custom_template = prompt_config.get("custom_chat_template")
        # No custom template configured: keep the tokenizer's default template.
        # (os.path.join would raise TypeError on None otherwise.)
        if not custom_template:
            return
        template_path = os.path.join(
            os.path.dirname(__file__),
            custom_template
        )
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"✗ Custom chat template not found: {template_path}")
        with open(template_path, "r", encoding="utf-8") as f:
            self.tokenizer.chat_template = f.read()
        console.print(f"✓ Loaded custom chat template: {custom_template}", style=success_style)

    def _get_data_file_path(self, split: str) -> str:
        """Resolve the parquet file path for a split.

        Returns the first existing candidate, or the first candidate (which
        may not exist) so the caller can raise a helpful error.
        """
        base_dir = self.data_dir if self.data_dir else "./data"
        filename = f"{self.task_name}_{split}.parquet"
        possible_paths = [
            os.path.join(base_dir, self.task_name, filename),
        ]
        for file_path in possible_paths:
            if os.path.exists(file_path):
                return file_path
        return possible_paths[0]

    def _get_sample_data_file_path(self, split: str, sample_size: int) -> str:
        """Resolve the cached-sample parquet path for (split, sample_size)."""
        base_dir = self.data_dir if self.data_dir else "./data"
        possible_paths = [
            os.path.join(base_dir, self.task_name, f"{self.task_name}_{split}_sample_{sample_size}.parquet"),
            os.path.join(base_dir, f"{self.task_name}_{split}_sample_{sample_size}.parquet"),
        ]
        for path in possible_paths:
            if os.path.exists(path):
                return path
        return possible_paths[0]

    def _load_dataframe(self, split: str) -> pd.DataFrame:
        """Load the full DataFrame for a split.

        Raises:
            FileNotFoundError: If the resolved data file does not exist.
            ValueError: If the file is not a parquet file.
        """
        data_file = self._get_data_file_path(split)
        if not os.path.exists(data_file):
            raise FileNotFoundError(f"Data file not found: {data_file}")
        console.print(f"Loading data file: {data_file}")
        if data_file.endswith('.parquet'):
            return pd.read_parquet(data_file)
        raise ValueError(f"Unsupported file format: {data_file}")

    def _sample_data(self, df: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Return the first sample_size rows (deterministic head sampling)."""
        if sample_size >= len(df):
            return df
        console.print(f"Sampling {sample_size} samples (total: {len(df)})")
        return df.head(sample_size)

    def _save_sample_data(
        self,
        df: pd.DataFrame,
        split: str,
        sample_size: int
    ):
        """Persist a sampled DataFrame as parquet so later runs can reuse it."""
        sample_file = self._get_sample_data_file_path(split, sample_size)
        sample_dir = os.path.dirname(sample_file)
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)
        df.to_parquet(sample_file, index=False)
        console.print(f"Sample data saved to: {sample_file}")

    def _load_sample_dataframe(self, split: str, sample_size: int) -> Optional[pd.DataFrame]:
        """Load the cached sampled DataFrame, or None when no cache exists."""
        sample_file = self._get_sample_data_file_path(split, sample_size)
        if not os.path.exists(sample_file):
            return None
        console.print(f"Loading sample data from cache: {sample_file}")
        return pd.read_parquet(sample_file)

    def _process_dataframe(self, df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
        """Convert a messages-based DataFrame into model-input samples.

        Rows with missing or unparseable messages/metadata/answer are skipped
        with a console notice rather than raising.
        """
        self._load_custom_chat_template()
        result = {}
        prompt_config = self.task_config.get("prompt_config", {})
        # Command-line parameter has higher priority than config
        if self.enable_thinking is not None:
            enable_thinking = self.enable_thinking
        else:
            enable_thinking = prompt_config.get("enable_thinking", False)
        console.print(f"[cyan]Auto Thinking: {'✓ Enabled' if enable_thinking else '✗ Disabled'}[/cyan]")
        for idx, row in df.iterrows():
            sample_id = str(idx)
            messages = row.get('messages')
            if self._is_empty_value(messages):
                console.print(f"Sample {sample_id}: messages is empty, skipping")
                continue
            # Messages may be stored as a JSON string in parquet.
            if isinstance(messages, str):
                try:
                    messages = json.loads(messages)
                except Exception:
                    console.print(f"Sample {sample_id}: failed to parse messages, skipping")
                    continue
            messages = self._convert_messages_format(messages)
            try:
                formatted_prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=enable_thinking,
                )
            except Exception as e:
                console.print(f"Sample {sample_id}: failed to apply chat template: {e}, skipping")
                continue
            metadata_raw = row.get('metadata')
            if self._is_empty_value(metadata_raw):
                console.print(f"Sample {sample_id}: metadata is empty, skipping")
                continue
            if isinstance(metadata_raw, str):
                try:
                    metadata_dict = json.loads(metadata_raw)
                except Exception:
                    console.print(f"Sample {sample_id}: failed to parse metadata, skipping")
                    continue
            elif isinstance(metadata_raw, dict):
                metadata_dict = metadata_raw
            else:
                console.print(f"Sample {sample_id}: invalid metadata format, skipping")
                continue
            answer = metadata_dict.get('answer')
            if self._is_empty_value(answer):
                console.print(f"Sample {sample_id}: answer is empty in metadata, skipping")
                continue
            result[sample_id] = {
                "prompt": formatted_prompt,
                "ground_truth": str(answer).strip(),
                "metadata": self._make_metadata_serializable(idx, metadata_dict)
            }
        console.print(f"[green]Loaded {len(result)} samples for {self.task_name}[/green]")
        return result

    def _make_metadata_serializable(
        self,
        idx: Any,
        metadata_dict: dict,
    ) -> dict:
        """Build a JSON-serializable metadata dict with the row index prepended.

        The 'answer' key is excluded (it is surfaced as ground_truth instead);
        unlike the previous `del`-based version, the caller's dict is not mutated.
        """
        remaining = {k: v for k, v in metadata_dict.items() if k != "answer"}
        return {
            "row_index": int(idx) if hasattr(idx, '__int__') else str(idx),
            **remaining,
        }
================================================
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/__init__.py
================================================
"""
Item Understand Task Module
"""
from .config import ITEM_UNDERSTAND_CONFIG
from .evaluator import ItemUnderstandEvaluator
from . import utils
__all__ = [
"ITEM_UNDERSTAND_CONFIG",
"ItemUnderstandEvaluator",
"utils",
]
================================================
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/config.py
================================================
"""
Item Understand Task Configuration
"""
# Item Understand Task Configuration
ITEM_UNDERSTAND_CONFIG = {
    "name": "item_understand",
    "source": "Kuaishou Internal",
    "splits": ["test"],
    # Dataset sizing: full split size and how many rows to evaluate
    # (equal here, so no down-sampling happens).
    "size": 500,
    "sample_size": 500,
    "description": "Video SID to Caption generation task",
    # Column names the loader expects in the parquet files.
    "data_fields": {
        "messages_field": "messages",
        "metadata_field": "metadata",
    },
    "prompt_config": {
        "enable_thinking": False,  # Enable thinking mode for apply_chat_template
        "custom_chat_template": "qwen3_soft_switch.jinja2",  # Custom jinja2 template (file in v1_0 directory)
    },
    # Generation parameter configuration
    "generation_config": {
        "num_return_sequences": 1,
        "max_new_tokens": 128,
        "temperature": 0.01,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "do_sample": False,  # greedy decoding; temperature/top_p are then inert
        "num_return_thinking_sequences": 1,
        "max_new_thinking_tokens": 1000,
    },
    # Evaluation / metric configuration (consumed by ItemUnderstandEvaluator).
    "evaluation_config": {
        "metrics": ["macro_wip_double_weighted_f1", "micro_wip_double_weighted_f1"],
        "bertscore_model_type": "bert-base-chinese",
        "bertscore_num_layers": 9,
        "bertscore_lang": "zh",
        # WIP (Weighted Information Points) evaluation config
        "wip_enabled": True,  # Whether to enable WIP evaluation
        "wip_judge_model": "gemini",  # Judge LLM type: gemini/deepseek/claude
        "wip_max_workers": 1,  # Concurrent workers for LLM calls
        "wip_core_threshold": 5,  # Core threshold for importance score (1-5)
        "wip_max_samples": 500,  # Max samples to evaluate (None for all)
    }
}
================================================
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/evaluator.py
================================================
"""
Item Understand Evaluator
Evaluates model predictions on Item Understand task using WIP (LLM-as-Judge).
"""
import os
from typing import Dict, Any, Tuple, List
from benchmark.console import console
from benchmark.tasks.v1_0.base_evaluator import BaseEval
class ItemUnderstandEvaluator(BaseEval):
    """Evaluator for the Item Understand task.

    Scores model-generated captions against references using WIP
    (Weighted Information Points) with an LLM acting as judge.
    """

    @property
    def required_metrics(self) -> List[str]:
        """Overall metric keys that must exist for cached results to be reused."""
        return ["macro_wip_double_weighted_f1"]

    def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
        """
        Compute all metrics from scratch.

        Returns:
            Tuple of (metrics, per_sample_metrics)
        """
        sample_ids = list(self.samples.keys())

        # Collect the reference text and the first generation for each sample.
        predictions = []
        references = []
        for sid in sample_ids:
            sample = self.samples[sid]
            references.append(sample.get("ground_truth", ""))
            generations = sample.get("generations", [])
            predictions.append(generations[0] if generations else "")

        eval_config = self.task_config.get("evaluation_config", {})

        per_sample_metrics: Dict[str, Dict[str, Any]] = {sid: {} for sid in sample_ids}
        metrics: Dict[str, Any] = {
            "num_samples": len(self.samples)
        }

        # WIP evaluation via LLM-as-Judge, when enabled in the task config.
        if eval_config.get("wip_enabled", False):
            console.print("[cyan]WIP evaluation enabled, starting LLM-as-Judge evaluation...[/cyan]")
            wip_metrics, wip_per_sample = self._evaluate_wip(
                sample_ids=sample_ids,
                predictions=predictions,
                references=references,
                eval_config=eval_config
            )
            metrics.update(wip_metrics)
            for sid in sample_ids:
                if sid in wip_per_sample:
                    per_sample_metrics[sid].update(wip_per_sample[sid])

        if self.debug and self.predictions_dir:
            self._save_debug_info(metrics, per_sample_metrics, predictions, references)

        return metrics, per_sample_metrics

    def _evaluate_wip(
        self,
        sample_ids: list,
        predictions: list,
        references: list,
        eval_config: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
        """
        Perform WIP (Weighted Information Points) evaluation using LLM-as-Judge.

        Args:
            sample_ids: List of sample IDs
            predictions: Prediction text per sample (aligned with sample_ids)
            references: Reference text per sample (aligned with sample_ids)
            eval_config: Evaluation configuration

        Returns:
            Tuple of (wip_metrics, wip_per_sample_metrics); both empty on failure.
        """
        try:
            from api import get_client_from_config
            from benchmark.tasks.v1_0.item_understand.utils import evaluate_wip
        except ImportError as e:
            console.print(f"[red]Failed to import WIP evaluation modules: {e}[/red]")
            return {}, {}

        # Pull WIP-specific knobs out of the evaluation config.
        wip_judge_model = eval_config.get("wip_judge_model", "deepseek")
        workers = eval_config.get("wip_max_workers", 5)
        sample_cap = eval_config.get("wip_max_samples", 100)
        core_threshold = eval_config.get("wip_core_threshold", 5)
        # Ground-truth WIP extractions are cached under <data_dir>/<task_name>.
        gt_cache_dir = os.path.join(self.data_dir, self.task_name)

        # BERTScore settings come from the shared evaluation_config.
        bertscore_model = eval_config.get("bertscore_model_type", "bert-base-chinese")
        bertscore_num_layers = eval_config.get("bertscore_num_layers", 9)

        try:
            llm_client = get_client_from_config(wip_judge_model)
            console.print(f"[green]Using {wip_judge_model} as WIP judge[/green]")
        except Exception as e:
            console.print(f"[red]Failed to create LLM client for WIP evaluation: {e}[/red]")
            return {}, {}

        # evaluate_wip expects {sample_id: text} mappings.
        predictions_by_id = dict(zip(sample_ids, predictions))
        references_by_id = dict(zip(sample_ids, references))

        # Cache files are named after the judge model when the client exposes it.
        model_name = getattr(llm_client, 'model_name', wip_judge_model)

        try:
            wip_metrics, wip_per_sample = evaluate_wip(
                predictions=predictions_by_id,
                references=references_by_id,
                llm_client=llm_client,
                max_workers=workers,
                max_samples=sample_cap,
                gt_cache_dir=gt_cache_dir,
                model_name=model_name,
                save_dir=self.predictions_dir,
                bertscore_model=bertscore_model,
                bertscore_num_layers=bertscore_num_layers,
                core_threshold=core_threshold,
            )
            console.print(f"[green]WIP evaluation completed: {wip_metrics.get('wip_num_samples', 0)} samples evaluated[/green]")
            return wip_metrics, wip_per_sample
        except Exception as e:
            console.print(f"[red]WIP evaluation failed: {e}[/red]")
            import traceback
            traceback.print_exc()
            return {}, {}

    def _save_debug_info(
        self,
        metrics: Dict[str, Any],
        per_sample_metrics: Dict[str, Dict[str, Any]],
        predictions: list,
        references: list
    ):
        """
        Save detailed debug information to file and print a summary.

        Args:
            metrics: Overall metrics
            per_sample_metrics: Per-sample metrics
            predictions: List of predictions
            references: List of references
        """
        debug_info: Dict[str, Any] = {
            "overall_metrics": metrics,
            "per_sample_metrics": per_sample_metrics,
            "sample_count": len(predictions),
        }

        # Include up to the first 10 samples as worked examples.
        sample_ids = list(self.samples.keys())
        example_fields = (
            "wip_unweighted_f1",
            "wip_unweighted_core_f1",
            "wip_importance_weighted_f1",
            "wip_importance_weighted_core_f1",
            "wip_double_weighted_f1",
            "wip_double_weighted_core_f1",
        )
        examples = []
        for i, sid in enumerate(sample_ids[:10]):
            entry = {
                "sample_id": sid,
                "prediction": predictions[i],
                "reference": references[i],
            }
            for field in example_fields:
                entry[field] = per_sample_metrics[sid].get(field)
            examples.append(entry)
        debug_info["examples"] = examples

        # Save to file using the base class helper.
        self._save_debug_json(debug_info, filename="debug.json")

        # Print summary statistics.
        console.print(f"Total samples: {metrics['num_samples']}")
        if metrics.get('macro_wip_unweighted_f1') is not None:
            console.print(f"Macro WIP Unweighted F1: {metrics['macro_wip_unweighted_f1']:.4f}")
        if metrics.get('macro_wip_double_weighted_f1') is not None:
            console.print(f"Macro WIP Double-weighted F1: {metrics['macro_wip_double_weighted_f1']:.4f}")
================================================
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/utils.py
================================================
import json
import os
import re
from typing import Dict, List, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from benchmark.console import console
# Prompt sent to the judge LLM to decompose a caption into atomic, scored
# "Weighted Information Points" (WIPs). Expects a pure-JSON {"wips": [...]} reply;
# the single {} placeholder is filled with the text to analyze.
WIP_EXTRACTION_PROMPT = """你是一位顶级的【信息抽取专家】,擅长从非结构化的文本中解析出结构化的信息。
### 你的核心任务
你的任务是分析我提供的描述性文字,并将其分解为结构化的【原子化且唯一】的"信息点"列表。
### 输出结构
对于列表中的每一个信息点,你必须提供:
1. **info_point**: 一个简洁的、陈述事实的短语。
2. **importance_score**: 一个 [1, 5] 之间的【整数】,代表该信息点的重要性。
---
### 关键原则 (必须遵守)
1. **原子性 (Atomic):** 每个 `info_point` 应只包含一个独立的事实。
* (好): `{{"info_point": "女孩在吃饭", "importance_score": 4}}`
* (差): `{{"info_point": "女孩在吃饭,妈妈在旁边看", "importance_score": 4}}`
2. **唯一性 (Unique):** 确保你提取的每个 `info_point` 都是**概念上唯一**的。
3. **合并 (Consolidate):** 如果原始文本中的多个短语描述的是【同一个核心思想】,你【必须】将它们合并成一个单一的、最具代表性的 `info_point`。
* (例如): 如果文本说 "活动环境是温馨的" 和 "视频色彩营造温馨氛围",你应该只提取一个,如:`{{"info_point": "视频氛围温馨", "importance_score": 5}}`。
* **不要创建重复或语义高度重叠的条目。**
---
### 打分指南 (1-5分制)
* **5分 (绝对核心):** 视频的"灵魂"。如果缺少这个点,整个摘要就毫无意义。(例如:"如何制作煎蛋卷"、"XX游戏的评测")
* **4分 (关键信息):** 视频的"骨架"。关键的事件、步骤或场景。(例如:"打散三个鸡蛋"、"使用了不粘锅"、"游戏画面评测")
* **3分 (重要细节):** 视频的"肉"。支撑骨架的具体、重要的细节。(例如:"加入了盐和胡椒"、"用中火加热黄油"、"角色动作流畅")
* **2分 (补充细节):** 补充性的上下文或次要信息。(例如:"煎蛋卷折叠了三次"、"背景音乐很好听")
* **1分 (琐碎信息):** 琐碎的、风格化的或背景性的描述。(例如:"主持人穿着蓝色围裙"、"视频光线很好")
---
### 格式与示例
你的输出必须是【纯粹的 JSON 格式】,可以被 `json.loads` 直接解析。JSON应包含一个 "wips" 键,其值为一个列表。如果文本中没有可提取的信息点,请返回 `{{"wips": []}}`。
**[示例输入]**
这是一段关于如何制作法式煎蛋卷的教程视频。主持人首先将三个鸡蛋打入碗中,并加入了盐和一小撮胡椒进行搅拌。视频强调了使用中火和不粘锅的重要性。接着,她在锅中融化了一块黄油,然后倒入蛋液。在烹饪过程中,她不断晃动平底锅,并将边缘的蛋液推向中心。最后,她将煎蛋卷折叠成三折,盛入盘中。整个过程非常快速。
**[示例输出]**
```json
{{
"wips": [
{{
"info_point": "教程:如何制作法式煎蛋卷",
"importance_score": 5
}},
{{
"info_point": "使用三个鸡蛋,加盐和胡椒搅拌",
"importance_score": 3
}},
{{
"info_point": "强调使用中火",
"importance_score": 4
}},
{{
"info_point": "使用不粘锅和黄油",
"importance_score": 4
}},
{{
"info_point": "晃动锅并将蛋液边缘推向中心",
"importance_score": 3
}},
{{
"info_point": "煎蛋卷被折叠成三折",
"importance_score": 2
}},
{{
"info_point": "烹饪过程快速",
"importance_score": 1
}}
]
}}
```
现在,请开始分析我提供的描述性文字:
{}
你的输出结果 (请严格按照上述要求返回一个格式规整的 JSON,可以被 json.loads 直接解析。请不要在 JSON 数据前后添加任何额外的解释性文字或代码块标记): """
# Prompt sent to the judge LLM to align ground-truth WIPs with model-generated
# WIPs. Expects a pure-JSON reply with "matches", "unmatched_model_wips"
# (hallucinations / false positives) and "unmatched_gt_wips" (misses / false
# negatives); the two {} placeholders take the GT list and the model list.
WIP_MATCHING_PROMPT = """你是一位极其严谨的**语义匹配专家**。你的任务是精确地对比两组关于同一个视频摘要的结构化信息点 (WIPs),并找出它们之间的匹配关系。
**背景信息:**
- **Ground Truth WIPs (GT列表)**: 这是视频摘要的"事实标准",代表视频中真实存在的所有核心信息。每个点都有一个 [1-5] 的重要性分数 (`importance_score`)。
- **Model-Generated WIPs (模型列表)**: 这是由一个AI模型生成的摘要信息点,代表它"声称"在视频中看到的内容。每个点也有一个 [1-5] 的重要性分数。
**你的核心任务:**
对比这两个列表,并输出一个包含三类结果的JSON对象:
1. **`matches`**: 一个匹配对的列表。对于"模型列表"中的每一个项,如果在"GT列表"中找到了一个**语义上非常相似**的对应项,就将它们配对。
2. **`unmatched_model_wips` (幻觉)**: "模型列表"中,那些在"GT列表"里找不到任何合理对应项的条目。这些代表了模型的**幻觉 (False Positives)**。
3. **`unmatched_gt_wips` (漏报)**: "GT列表"中,那些没有被"模型列表"中任何条目匹配到的条目。这些代表了模型的**漏报 (False Negatives)**。
**至关重要的匹配规则:**
1. **语义核心**: 匹配的核心是 `info_point` 的语义。
2. **部分匹配**: 如果两个 `info_point` 语义上"部分重叠"但"不完全相同",你【也应该】将它们匹配。
* (例如): GT的 `"一场激烈精彩的篮球比赛"` 和 Gen的 `"球员在打篮球"` 应该被【匹配】(因为核心"篮球"匹配上了)。
* (例如): GT的 `"评测《魔龙巢穴:暗影崛起》"` 和 Gen的 `"评测《魔龙巢穴:冰封王座》"` 应该被【匹配】(因为核心"《魔龙巢穴》评测"匹配上了)。
3. **一对一匹配**: 找出最佳的匹配组合。
---
**[输出结构示例]**
**[输入]**
- GT列表: `[
{{"info_point": "节气是秋分", "importance_score": 5}},
{{"info_point": "农民在收割稻谷", "importance_score": 4}}
]`
- 模型列表: `[
{{"info_point": "这是一个关于秋分的视频", "importance_score": 4}},
{{"info_point": "狗在田里跑", "importance_score": 1}}
]`
**[你的输出]**
```json
{{
"matches": [
{{
"model_wip": {{"info_point": "这是一个关于秋分的视频", "importance_score": 4}},
"gt_wip": {{"info_point": "节气是秋分", "importance_score": 5}}
}}
],
"unmatched_model_wips": [
{{
"info_point": "狗在田里跑",
"importance_score": 1
}}
],
"unmatched_gt_wips": [
{{
"info_point": "农民在收割稻谷",
"importance_score": 4
}}
]
}}
```
现在,请开始你的匹配工作:
[Ground Truth WIPs (GT列表)]
{}
[Model-Generated WIPs (模型列表)]
{}
你的匹配结果 (请严格按照上述要求返回一个格式规整的 JSON,可以被 json.loads 直接解析。请不要在 JSON 数据前后添加任何额外的解释性文字或代码块标记): """
def extract_json_from_response(response: str) -> Optional[Dict]:
    """
    Extract JSON from an LLM response, tolerating a Markdown code fence.

    The previous implementation used ``rstrip('```')``/``lstrip('```json')``,
    but str.strip-family methods treat their argument as a SET of characters,
    so e.g. a response ending in "```\\n" was never unfenced and failed to
    parse. A regex is used instead to strip an optional ```/```json fence.

    Args:
        response: Raw LLM response text (may be None/empty).

    Returns:
        The parsed JSON object, or None when the response is empty or not
        valid JSON (the raw response is printed for debugging in that case).
    """
    if not response:
        return None
    text = response.strip()
    # Strip a surrounding Markdown fence: ```json ... ``` or ``` ... ```.
    fence = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
    if fence:
        text = fence.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        print(response)
        return None
def extract_wips_single(
    text: str,
    llm_client
) -> Tuple[Optional[List[Dict]], Optional[str]]:
    """
    Run the WIP-extraction prompt on one text and parse the judge's reply.

    Args:
        text: Input text to extract WIPs from
        llm_client: LLM client instance (with built-in retry mechanism)

    Returns:
        A ``(wips_list, error_message)`` pair: the parsed WIP list with a None
        error on success, otherwise None plus a short error description.
    """
    prompt = WIP_EXTRACTION_PROMPT.format(text)
    try:
        parsed = extract_json_from_response(llm_client.generate(prompt))
    except Exception as e:
        return None, f"API error: {str(e)}"
    if parsed is not None and "wips" in parsed:
        return parsed["wips"], None
    return None, "Failed to parse JSON from response"
def extract_wips_batch(
    texts: Dict[str, str],
    llm_client,
    max_workers: int = 5,
    desc: str = "Extracting WIPs"
) -> Tuple[Dict[str, List[Dict]], Dict[str, str]]:
    """
    Extract WIPs for many texts concurrently via a thread pool.

    Args:
        texts: Dict of {sample_id: text}
        llm_client: LLM client instance (with built-in retry mechanism)
        max_workers: Number of concurrent workers
        desc: Progress bar description

    Returns:
        ``(results, errors)`` — both keyed by sample id; results holds parsed
        non-empty WIP lists, errors holds failure messages.
    """
    results: Dict[str, List[Dict]] = {}
    errors: Dict[str, str] = {}

    def worker(sid: str, text: str):
        extracted, err = extract_wips_single(text, llm_client)
        return sid, extracted, err

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(worker, sid, text): sid
            for sid, text in texts.items()
        }
        for done in tqdm(as_completed(pending), total=len(pending), desc=desc):
            sid, extracted, err = done.result()
            if extracted:
                results[sid] = extracted
            if err:
                errors[sid] = err

    # Report how many extractions were attempted / parsed / non-empty.
    total_attempted = len(texts)
    total_parsed = len(results)
    valid_results = sum(1 for extracted in results.values() if extracted)
    console.print(f"[cyan]{desc} statistics: {total_attempted} attempted, {total_parsed} parsed, {valid_results} valid (non-empty)[/cyan]")
    return results, errors
def match_wips_single(
    gt_wips: List[Dict],
    model_wips: List[Dict],
    llm_client
) -> Tuple[Optional[Dict], Optional[str]]:
    """
    Match WIPs between ground truth and a model generation via the judge LLM.

    Args:
        gt_wips: Ground truth WIPs list
        model_wips: Model-generated WIPs list
        llm_client: LLM client instance (with built-in retry mechanism)

    Returns:
        ``(match_result, error_message)`` — the parsed match dict with a None
        error on success, otherwise None plus a short error description.
    """
    prompt = WIP_MATCHING_PROMPT.format(
        json.dumps(gt_wips, ensure_ascii=False, indent=2),
        json.dumps(model_wips, ensure_ascii=False, indent=2),
    )
    try:
        parsed = extract_json_from_response(llm_client.generate(prompt))
    except Exception as e:
        return None, f"API error: {str(e)}"
    required_keys = ("matches", "unmatched_model_wips", "unmatched_gt_wips")
    if parsed is not None and all(key in parsed for key in required_keys):
        return parsed, None
    return None, "Failed to parse match JSON from response"
def match_wips_batch(
    gt_wips_dict: Dict[str, List[Dict]],
    model_wips_dict: Dict[str, List[Dict]],
    llm_client,
    max_workers: int = 5
) -> Tuple[Dict[str, Dict], Dict[str, str]]:
    """
    Match WIPs for many samples concurrently via a thread pool.

    Args:
        gt_wips_dict: Dict of {sample_id: gt_wips_list}
        model_wips_dict: Dict of {sample_id: model_wips_list}
        llm_client: LLM client instance (with built-in retry mechanism)
        max_workers: Number of concurrent workers

    Returns:
        ``(results, errors)`` keyed by sample id.
    """
    results: Dict[str, Dict] = {}
    errors: Dict[str, str] = {}

    # Only samples with non-empty WIP lists on BOTH sides can be matched.
    eligible = {
        sid for sid in gt_wips_dict.keys() & model_wips_dict.keys()
        if gt_wips_dict[sid] and model_wips_dict[sid]
    }

    def worker(sid: str):
        outcome, err = match_wips_single(gt_wips_dict[sid], model_wips_dict[sid], llm_client)
        return sid, outcome, err

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(worker, sid): sid for sid in eligible}
        for done in tqdm(as_completed(pending), total=len(pending), desc="Matching WIPs"):
            sid, outcome, err = done.result()
            if outcome is not None:
                results[sid] = outcome
            if err is not None:
                errors[sid] = err

    # A parsed result counts as "valid" when any of its three lists is non-empty.
    total_attempted = len(eligible)
    total_parsed = len(results)
    valid_results = 0
    for outcome in results.values():
        if outcome and (
            outcome.get("matches", [])
            or outcome.get("unmatched_model_wips", [])
            or outcome.get("unmatched_gt_wips", [])
        ):
            valid_results += 1
    console.print(f"[cyan]Matching statistics: {total_attempted} attempted, {total_parsed} parsed, {valid_results} valid (non-empty)[/cyan]")
    return results, errors
def get_wip_score_int(wip: Optional[Dict]) -> int:
    """Return the importance score of a WIP entry, falling back to 1 when the
    entry is missing/empty or carries no 'importance_score' key."""
    return wip.get("importance_score", 1) if wip else 1
def calculate_unweighted_metrics(match_results: Dict[str, Dict], core_threshold: int = 5) -> Dict[str, Any]:
    """
    Calculate unweighted (count-based) F1 metrics, per sample and macro-averaged.

    Args:
        match_results: Dict of {sample_id: match_result}
        core_threshold: Threshold for core WIPs (importance_score >= threshold)

    Returns:
        Dict with macro F1, core versions, and per-sample F1s (unweighted);
        empty dict when match_results is empty.
    """
    if not match_results:
        return {}

    per_sample: Dict[str, Dict[str, float]] = {}
    for sid, outcome in match_results.items():
        # Empty/falsy results score zero on both F1 variants.
        if not outcome:
            per_sample[sid] = {"overall_f1": 0.0, "core_f1": 0.0}
            continue

        matched = outcome.get("matches", [])
        hallucinated = outcome.get("unmatched_model_wips", [])
        missed = outcome.get("unmatched_gt_wips", [])

        # Overall counts: every matched pair is a TP, every unmatched model
        # WIP a FP, every unmatched GT WIP a FN.
        tp, fp, fn = len(matched), len(hallucinated), len(missed)

        # Core counts: restrict to WIPs whose importance reaches the threshold
        # (matches are judged by their GT-side score).
        core_tp = sum(1 for m in matched if get_wip_score_int(m.get("gt_wip", {})) >= core_threshold)
        core_fp = sum(1 for w in hallucinated if get_wip_score_int(w) >= core_threshold)
        core_fn = sum(1 for w in missed if get_wip_score_int(w) >= core_threshold)

        denom = 2 * tp + fp + fn
        core_denom = 2 * core_tp + core_fp + core_fn
        per_sample[sid] = {
            "overall_f1": (2 * tp / denom) if denom > 0 else 0.0,
            "core_f1": (2 * core_tp / core_denom) if core_denom > 0 else 0.0,
        }

    # Macro F1 = mean of the per-sample F1s.
    scored = [entry for entry in per_sample.values() if entry]
    count = len(scored)
    return {
        "macro_wip_unweighted_f1": (sum(e["overall_f1"] for e in scored) / count) if count else 0.0,
        "macro_wip_unweighted_core_f1": (sum(e["core_f1"] for e in scored) / count) if count else 0.0,
        "per_sample": per_sample,
    }
def calculate_importance_weighted_metrics(
match_results: Dict[str, Dict],
core_threshold: int = 5
) -> Dict[str, Any]:
"""
Calculate importance-weighted metrics (weighted by importance_score only) with macro and per-sample versions.
Args:
match_results: Dict of {sample_id: match_result}
core_threshold: Threshold for core WIPs (importance_score >= threshold)
Returns:
Dict with macro F1, core versions, and per-sample F1s (importance-weighted)
"""
if not match_results:
return {}
# Per-sample metrics (for macro calculation)
per_sample = {}
for sample_id, result in match_results.items():
if not result:
per_sample[sample_id] = {"overall_f1": 0.0, "core_f1": 0.0}
continue
# Sample-level metrics
sample_tp, sample_fp, sample_fn = 0.0, 0.0, 0.0
sample_core_tp, sam
gitextract_yy4wnsy3/
├── .gitignore
├── README.md
├── benchmarks/
│ ├── LICENSE
│ ├── README.md
│ ├── api/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── config/
│ │ │ └── llm_config.json
│ │ ├── deepseek.py
│ │ ├── example.py
│ │ └── gemini.py
│ ├── benchmark/
│ │ ├── __init__.py
│ │ ├── base_generator.py
│ │ ├── benchmark.py
│ │ ├── checkpoint_utils.py
│ │ ├── console.py
│ │ ├── generation_runner.py
│ │ ├── gpu_utils.py
│ │ └── tasks/
│ │ ├── __init__.py
│ │ ├── tasks.py
│ │ └── v1_0/
│ │ ├── __init__.py
│ │ ├── base_evaluator.py
│ │ ├── base_loader.py
│ │ ├── item_understand/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── label_pred/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── mfu_evaluator.py
│ │ ├── qwen3.jinja2
│ │ ├── qwen3_soft_switch.jinja2
│ │ ├── rec_reason/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ └── utils.py
│ │ ├── recommendation/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── evaluator.py
│ │ │ ├── utils.py
│ │ │ └── utils_by_pid.py
│ │ └── registry.py
│ ├── eval_script.sh
│ ├── pyproject.toml
│ ├── requirements.txt
│ └── scripts/
│ ├── __init__.py
│ ├── eval_dev_results.py
│ ├── init_ray.sh
│ ├── init_ray_cluster.sh
│ └── ray-vllm/
│ ├── evaluate.py
│ └── utils/
│ ├── __init__.py
│ ├── arguments.py
│ └── generator.py
├── data/
│ ├── README.md
│ ├── general_text/
│ │ ├── pretrain.csv
│ │ └── sft.csv
│ ├── onerec_data/
│ │ ├── README.md
│ │ ├── pretrain/
│ │ │ ├── item_understand.py
│ │ │ ├── user_profile.py
│ │ │ └── video_rec.py
│ │ ├── run.sh
│ │ └── sft/
│ │ ├── ad_rec.py
│ │ ├── interactive_rec.py
│ │ ├── item_understand.py
│ │ ├── label_cond_rec.py
│ │ ├── label_pred.py
│ │ ├── product_rec.py
│ │ ├── rec_reason.py
│ │ └── video_rec.py
│ ├── prepare_distillation.sh
│ ├── prepare_pretrain.sh
│ ├── prepare_rl.sh
│ ├── prepare_sft.sh
│ └── scripts/
│ ├── parquet_unicode_fix.py
│ ├── sample_data.py
│ ├── split_data.py
│ └── train_test_split.py
├── pretrain/
│ ├── .gitignore
│ ├── README.md
│ ├── examples/
│ │ ├── dataset_config/
│ │ │ ├── pretrain.json
│ │ │ └── sft.json
│ │ ├── posttrain_sft.sh
│ │ ├── pretrain_stg1.sh
│ │ └── pretrain_stg2.sh
│ ├── onerec_llm/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── dataloaders.py
│ │ │ ├── local_shuffle_buffer.py
│ │ │ └── qwen3_dataset.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ └── ce.py
│ │ ├── models/
│ │ │ └── qwen3/
│ │ │ ├── __init__.py
│ │ │ ├── configuration_qwen3.py
│ │ │ ├── modeling_qwen3.py
│ │ │ └── modular_qwen3.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── checkpoint.py
│ │ │ ├── common.py
│ │ │ ├── distributed.py
│ │ │ ├── gradients.py
│ │ │ └── lr_schedulers.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── data_utils.py
│ │ ├── distributed.py
│ │ ├── ds_utils.py
│ │ ├── mfu_stats.py
│ │ ├── time_tracker.py
│ │ └── worker_utils.py
│ ├── recipes/
│ │ └── train_qwen3.py
│ ├── scripts/
│ │ ├── convert_checkpoint_to_hf.sh
│ │ ├── expand_qwen3_vocab.sh
│ │ ├── killall.sh
│ │ ├── numa_runner.sh
│ │ ├── test_cases_example.json
│ │ └── test_hf_model.sh
│ ├── set_env.sh
│ ├── tests/
│ │ └── test_qwen3_dataset_file_distribution.py
│ └── tools/
│ ├── model_converter/
│ │ ├── convert_checkpoint_to_hf.py
│ │ └── expand_qwen3_vocab.py
│ └── model_test/
│ └── test_hf_model.py
├── tokenizer/
│ ├── README.md
│ ├── infer_res_kmeans.py
│ ├── res_kmeans.py
│ └── train_res_kmeans.py
├── verl_distillation/
│ ├── LICENSE
│ ├── README.md
│ ├── README_ORIGINAL.md
│ ├── deploy_env.sh
│ ├── docker/
│ │ ├── Apptainerfile.rocm
│ │ ├── Dockerfile.extention.awsefa
│ │ ├── Dockerfile.ngc.vllm
│ │ ├── Dockerfile.ngc.vllm0.8
│ │ ├── Dockerfile.ngc.vllm0.8.sagemaker
│ │ ├── Dockerfile.rocm
│ │ ├── Dockerfile.rocm7
│ │ ├── Dockerfile.rocm_verl-0.3.0.post1
│ │ ├── Dockerfile.rocm_verl-0.4.1
│ │ ├── Dockerfile.sglang
│ │ ├── Dockerfile.vemlp.vllm.te
│ │ ├── Dockerfile.vllm.sglang.megatron.deepseek
│ │ ├── README.md
│ │ ├── ascend/
│ │ │ ├── Dockerfile.ascend_8.2.rc1_a2
│ │ │ └── Dockerfile.ascend_8.2.rc1_a3
│ │ ├── verl0.4-cu124-torch2.6-fa2.7.4/
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.12
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.12.deepep
│ │ │ ├── Dockerfile.app.sglang.vllm.mcore0.13.preview
│ │ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ │ ├── Dockerfile.app.vllm.mcore0.12.deepep
│ │ │ ├── Dockerfile.app.vllm.mcore0.13.preview
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ ├── verl0.5-cu126-torch2.7-fa2.7.4/
│ │ │ ├── Dockerfile.app.sglang0.4.10.post2.mcore0.13
│ │ │ ├── Dockerfile.app.sglang0.4.9.post6.mcore0.13
│ │ │ ├── Dockerfile.app.vllm.mcore0.13
│ │ │ ├── Dockerfile.app.vllm.mcore0.15
│ │ │ ├── Dockerfile.base.torch2.7.1
│ │ │ └── README.md
│ │ ├── verl0.5-cu126-torch2.7.1-fa2.8.0/
│ │ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ │ ├── Dockerfile.app.sglang.mcore0.13.preview
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ ├── verl0.5-preview-cu128-torch2.7.1-fa2.8.0/
│ │ │ ├── Dockerfile.app.sglang.megatron
│ │ │ ├── Dockerfile.base
│ │ │ └── README.md
│ │ └── verl0.6-cu128-torch2.8.0-fa2.7.4/
│ │ ├── Dockerfile.app.sglang
│ │ ├── Dockerfile.base
│ │ └── Dockerfile.vllm011.mcore_gpt-oss
│ ├── docs/
│ │ ├── Makefile
│ │ ├── README.md
│ │ ├── README_vllm0.7.md
│ │ ├── README_vllm0.8.md
│ │ ├── _static/
│ │ │ ├── custom.css
│ │ │ └── js/
│ │ │ ├── resizable-sidebar.js
│ │ │ └── runllm-widget.js
│ │ ├── advance/
│ │ │ ├── agent_loop.rst
│ │ │ ├── attention_implementation.rst
│ │ │ ├── checkpoint.rst
│ │ │ ├── dpo_extension.rst
│ │ │ ├── fsdp_extension.rst
│ │ │ ├── fully_async.md
│ │ │ ├── megatron_extension.rst
│ │ │ ├── one_step_off.md
│ │ │ ├── placement.rst
│ │ │ ├── ppo_lora.rst
│ │ │ ├── reward_loop.rst
│ │ │ ├── rollout_is.md
│ │ │ ├── rollout_skip.rst
│ │ │ ├── rollout_trace.rst
│ │ │ └── rope.rst
│ │ ├── algo/
│ │ │ ├── baseline.md
│ │ │ ├── collabllm.md
│ │ │ ├── dapo.md
│ │ │ ├── entropy.md
│ │ │ ├── gpg.md
│ │ │ ├── grpo.md
│ │ │ ├── opo.md
│ │ │ ├── ppo.md
│ │ │ ├── spin.md
│ │ │ └── sppo.md
│ │ ├── amd_tutorial/
│ │ │ ├── amd_build_dockerfile_page.rst
│ │ │ └── amd_vllm_page.rst
│ │ ├── api/
│ │ │ ├── data.rst
│ │ │ ├── single_controller.rst
│ │ │ ├── trainer.rst
│ │ │ └── utils.rst
│ │ ├── ascend_tutorial/
│ │ │ ├── ascend_profiling_en.rst
│ │ │ ├── ascend_profiling_zh.rst
│ │ │ ├── ascend_quick_start.rst
│ │ │ ├── ascend_sglang_quick_start.rst
│ │ │ └── dockerfile_build_guidance.rst
│ │ ├── conf.py
│ │ ├── data/
│ │ │ └── transfer_queue.md
│ │ ├── examples/
│ │ │ ├── config.rst
│ │ │ ├── gsm8k_example.rst
│ │ │ ├── multi_modal_example.rst
│ │ │ ├── ppo_code_architecture.rst
│ │ │ ├── sandbox_fusion_example.rst
│ │ │ └── skypilot_examples.rst
│ │ ├── faq/
│ │ │ └── faq.rst
│ │ ├── hybrid_flow.rst
│ │ ├── index.rst
│ │ ├── perf/
│ │ │ ├── best_practices.rst
│ │ │ ├── device_tuning.rst
│ │ │ ├── dpsk.md
│ │ │ ├── nsight_profiling.md
│ │ │ ├── perf_tuning.rst
│ │ │ └── verl_profiler_system.md
│ │ ├── preparation/
│ │ │ ├── prepare_data.rst
│ │ │ └── reward_function.rst
│ │ ├── requirements-docs.txt
│ │ ├── sglang_multiturn/
│ │ │ ├── interaction_system.rst
│ │ │ ├── multiturn.rst
│ │ │ ├── sandbox_fusion.rst
│ │ │ └── search_tool_example.rst
│ │ ├── single_controller.rst
│ │ ├── start/
│ │ │ ├── agentic_rl.rst
│ │ │ ├── install.rst
│ │ │ ├── more_resources.rst
│ │ │ ├── multinode.rst
│ │ │ ├── quickstart.rst
│ │ │ └── ray_debug_tutorial.rst
│ │ └── workers/
│ │ ├── fsdp_workers.rst
│ │ ├── megatron_workers.rst
│ │ ├── model_engine.rst
│ │ ├── ray_trainer.rst
│ │ └── sglang_worker.rst
│ ├── examples/
│ │ ├── data_preprocess/
│ │ │ ├── aime2024_multiturn_w_tool.py
│ │ │ ├── dapo_multiturn_w_tool.py
│ │ │ ├── full_hh_rlhf.py
│ │ │ ├── geo3k.py
│ │ │ ├── geo3k_multiturn_w_tool.py
│ │ │ ├── gsm8k.py
│ │ │ ├── gsm8k_multiturn_sft.py
│ │ │ ├── gsm8k_multiturn_w_interaction.py
│ │ │ ├── gsm8k_multiturn_w_tool.py
│ │ │ ├── gsm8k_tool_agent_loop.py
│ │ │ ├── hellaswag.py
│ │ │ ├── math_dataset.py
│ │ │ ├── multiturn.py
│ │ │ └── preprocess_search_r1_dataset.py
│ │ ├── generation/
│ │ │ ├── run_deepseek7b_mutli_node.sh
│ │ │ └── run_deepseek_v2_lite_math.sh
│ │ ├── gmpo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_qwen2_5-7b_math.sh
│ │ │ ├── test_dapo_7b_math.sh
│ │ │ └── test_dapo_qwen3_30b_math.sh
│ │ ├── gpg_trainer/
│ │ │ ├── gpg.md
│ │ │ ├── run_qwen2-7b_math.sh
│ │ │ └── run_qwen2-7b_math_megatron.sh
│ │ ├── grpo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_deepseek671b_math_megatron_80gb.sh
│ │ │ ├── run_deepseek671b_math_megatron_96gb.sh
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ ├── run_deepseek7b_llm_math.sh
│ │ │ ├── run_deepseek7b_llm_math_megatron.sh
│ │ │ ├── run_deepseek7b_llm_seq_balance.sh
│ │ │ ├── run_glm41v_9b.sh
│ │ │ ├── run_gptoss_20b.sh
│ │ │ ├── run_minicpmo2_6.sh
│ │ │ ├── run_mistral13b_skyworkrm_hhrlhf.sh
│ │ │ ├── run_moonlight16b_math_megatron.sh
│ │ │ ├── run_qwen2-7b.sh
│ │ │ ├── run_qwen2-7b_math.sh
│ │ │ ├── run_qwen2-7b_math_megatron.sh
│ │ │ ├── run_qwen2-7b_seq_balance.sh
│ │ │ ├── run_qwen2-7b_seq_balance_math_megatron.sh
│ │ │ ├── run_qwen2-7b_sgl_megatron.sh
│ │ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora.sh
│ │ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora_from_adapter.sh
│ │ │ ├── run_qwen2_5-7b_math_megatron_diff_tp.sh
│ │ │ ├── run_qwen2_5_32b_grpo_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_discrete_prof_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_e2e_prof_npu.sh
│ │ │ ├── run_qwen2_5_7b_grpo_npu.sh
│ │ │ ├── run_qwen2_5_vl-7b-megatron.sh
│ │ │ ├── run_qwen2_5_vl-7b-sglang.sh
│ │ │ ├── run_qwen2_5_vl-7b.sh
│ │ │ ├── run_qwen2_5_vl-7b_freeze_vision.sh
│ │ │ ├── run_qwen2_5_vl-7b_lora.sh
│ │ │ ├── run_qwen2_5_vl-7b_seq_balance.sh
│ │ │ ├── run_qwen2_5_vl_32b_npu.sh
│ │ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ │ ├── run_qwen2_5_vl_7b_npu.sh
│ │ │ ├── run_qwen3-235b_megatron_96gb.sh
│ │ │ ├── run_qwen3-32b_npu.sh
│ │ │ ├── run_qwen3-8b.sh
│ │ │ ├── run_qwen3-8b_npu.sh
│ │ │ ├── run_qwen3_8b_grpo_sglang_1k_spmd_npu.sh
│ │ │ ├── run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh
│ │ │ ├── run_qwen3_vl-235b-megatron.sh
│ │ │ ├── run_qwen3_vl-30b-megatron.sh
│ │ │ ├── run_qwen3_vl-8b-megatron.sh
│ │ │ ├── run_qwen3moe-30b_megatron_96gb.sh
│ │ │ └── run_seed_oss_36b.sh
│ │ ├── ppo_trainer/
│ │ │ ├── README.md
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ ├── run_deepseek7b_llm_modelscope.sh
│ │ │ ├── run_deepseek7b_llm_pfppo.sh
│ │ │ ├── run_deepseek7b_llm_sandbox_fusion.sh
│ │ │ ├── run_deepseek7b_llm_sp2.sh
│ │ │ ├── run_deepseek_full_hh_rlhf.sh
│ │ │ ├── run_deepseek_math_gsm8k_megatron.sh
│ │ │ ├── run_deepseek_math_gsm8k_megatron_nsys.sh
│ │ │ ├── run_gemma.sh
│ │ │ ├── run_moonlight16b_a3b_gsm8k_megatron.sh
│ │ │ ├── run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh
│ │ │ ├── run_qwen2-7b_math_gsm8k_megatron.sh
│ │ │ ├── run_qwen2-7b_rm.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance_fused_kernels.sh
│ │ │ ├── run_qwen2-7b_rm_seq_balance_nsys.sh
│ │ │ ├── run_qwen2-7b_seq_balance.sh
│ │ │ ├── run_qwen2-7b_sglang_seq_balance.sh
│ │ │ ├── run_qwen2.5-32b.sh
│ │ │ └── run_qwen3-8b_npu.sh
│ │ ├── ray/
│ │ │ └── tutorial.ipynb
│ │ ├── reinforce_plus_plus_trainer/
│ │ │ ├── run_qwen2-7b_math_rf.sh
│ │ │ └── run_qwen2-7b_math_rf_baseline.sh
│ │ ├── remax_trainer/
│ │ │ ├── run_qwen2.5-3b_seq_balance.sh
│ │ │ └── run_qwen2.5-7b_seq_balance.sh
│ │ ├── rloo_trainer/
│ │ │ └── run_qwen2-7b.sh
│ │ ├── rollout_importance_sampling/
│ │ │ ├── README.md
│ │ │ └── run_with_rollout_is.sh
│ │ ├── sft/
│ │ │ ├── gsm8k/
│ │ │ │ ├── run_deepseek_6b7.sh
│ │ │ │ ├── run_gemma_2b.sh
│ │ │ │ ├── run_gemma_7b.sh
│ │ │ │ ├── run_qwen3_8b_sft_peft_sp2_npu.sh
│ │ │ │ ├── run_qwen_05_peft.sh
│ │ │ │ ├── run_qwen_05_sp2.sh
│ │ │ │ ├── run_qwen_05_sp2_liger.sh
│ │ │ │ └── run_seed_oss_36b_sft.sh
│ │ │ └── multiturn/
│ │ │ └── run_qwen_05_sp2.sh
│ │ ├── sglang_multiturn/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── geo3k_multiturn_grpo.yaml
│ │ │ │ ├── geo3k_multiturn_megatron_grpo.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo_server.yaml
│ │ │ │ ├── gsm8k_multiturn_grpo_w_interaction.yaml
│ │ │ │ ├── gsm8k_multiturn_megatron_grpo.yaml
│ │ │ │ ├── interaction_config/
│ │ │ │ │ └── gsm8k_interaction_config.yaml
│ │ │ │ ├── retool_multiturn_grpo.yaml
│ │ │ │ ├── search_multiturn_grpo.yaml
│ │ │ │ ├── search_multiturn_grpo_one_step_off.yaml
│ │ │ │ └── tool_config/
│ │ │ │ ├── geo3k_tool_config.yaml
│ │ │ │ ├── gsm8k_tool_config.yaml
│ │ │ │ ├── mcp_server.json
│ │ │ │ ├── mcp_tool_config.yaml
│ │ │ │ ├── sandbox_fusion_tool_config.yaml
│ │ │ │ └── search_tool_config.yaml
│ │ │ ├── geo3k/
│ │ │ │ ├── run_qwen2.5-3b_geo3k_multiturn.sh
│ │ │ │ ├── run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh
│ │ │ │ └── run_qwen2.5-3b_megatron_geo3k_multiturn.sh
│ │ │ ├── run_qwen0.5b_gsm8k_multiturn_curriculum.sh
│ │ │ ├── run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu_server.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_server.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh
│ │ │ ├── run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh
│ │ │ ├── run_qwen2.5-3b_megatron_gsm8k_multiturn.sh
│ │ │ ├── run_qwen3-4b_gsm8k_multiturn.sh
│ │ │ ├── run_qwen3_4b_dapo_multiturn.sh
│ │ │ └── search_r1_like/
│ │ │ ├── local_dense_retriever/
│ │ │ │ ├── download.py
│ │ │ │ └── retrieval_server.py
│ │ │ └── run_qwen2.5-3b_instruct_search_multiturn.sh
│ │ ├── skypilot/
│ │ │ ├── README.md
│ │ │ ├── verl-grpo.yaml
│ │ │ ├── verl-multiturn-tools.yaml
│ │ │ └── verl-ppo.yaml
│ │ ├── slurm/
│ │ │ └── ray_on_slurm.slurm
│ │ ├── split_placement/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── ppo_trainer_split.yaml
│ │ │ ├── main_ppo_split.py
│ │ │ ├── run_deepseek7b_llm.sh
│ │ │ └── split_monkey_patch.py
│ │ ├── tuning/
│ │ │ ├── 0.5b/
│ │ │ │ └── qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 1.5b/
│ │ │ │ └── qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 14b/
│ │ │ │ ├── qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh
│ │ │ │ └── qwen2_14b_grpo_4_h800_fsdp_vllm.sh
│ │ │ ├── 32b/
│ │ │ │ ├── qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh
│ │ │ │ └── qwen2_32B_grpo_8_h20_megatron_vllm.sh
│ │ │ ├── 3b/
│ │ │ │ └── qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ ├── 70b/
│ │ │ │ ├── qwen2-70b_grpo_32_h20_fsdp_vllm.sh
│ │ │ │ ├── qwen2-70b_grpo_32_h800_fsdp_vllm.sh
│ │ │ │ └── qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh
│ │ │ └── 7b/
│ │ │ ├── qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh
│ │ │ └── qwen2-7b_grpo_2_h800_fsdp_vllm.sh
│ │ └── tutorial/
│ │ └── agent_loop_get_started/
│ │ ├── agent_loop_tutorial.ipynb
│ │ └── sandbox.py
│ ├── init_ray.sh
│ ├── init_ray_cluster.sh
│ ├── pyproject.toml
│ ├── recipe/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── char_count/
│ │ │ ├── README.md
│ │ │ ├── create_dataset.py
│ │ │ ├── reward_function.py
│ │ │ ├── train_grpo.sh
│ │ │ └── train_sft.sh
│ │ ├── collabllm/
│ │ │ ├── README.md
│ │ │ ├── collabllm_agent_loop.py
│ │ │ ├── collabllm_interation.py
│ │ │ ├── config/
│ │ │ │ ├── agent.yaml
│ │ │ │ └── collabllm_interaction_config.yaml
│ │ │ ├── metrics/
│ │ │ │ ├── accuracy.py
│ │ │ │ ├── bleu_score.py
│ │ │ │ ├── interactivity.py
│ │ │ │ ├── pass_rate.py
│ │ │ │ └── token_amount.py
│ │ │ ├── process_dataset.py
│ │ │ ├── reward_function.py
│ │ │ ├── train_rl_collabllm.sh
│ │ │ ├── train_sft_collabllm.sh
│ │ │ └── utils.py
│ │ ├── dapo/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── dapo_megatron_trainer.yaml
│ │ │ │ └── dapo_trainer.yaml
│ │ │ ├── dapo_ray_trainer.py
│ │ │ ├── main_dapo.py
│ │ │ ├── prepare_dapo_data.sh
│ │ │ ├── run_dapo_early_qwen2.5_32b.sh
│ │ │ ├── run_dapo_qwen2.5_32b.sh
│ │ │ ├── run_dapo_qwen2.5_32b_npu.sh
│ │ │ ├── run_dapo_qwen2.5_32b_rollout_is.sh
│ │ │ ├── run_dapo_qwen2.5_7b_npu.sh
│ │ │ ├── run_dapo_qwen3_14b_base_npu.sh
│ │ │ ├── run_dapo_qwen3_8b_base_npu.sh
│ │ │ ├── run_dapo_qwen3_moe_30b_base_fsdp_npu.sh
│ │ │ ├── run_dapo_qwen3_moe_30b_megatron_npu.sh
│ │ │ ├── run_dapo_wo_ds_qwen2.5_32b.sh
│ │ │ ├── runtime_env.yaml
│ │ │ ├── test_dapo_7b.sh
│ │ │ ├── test_dapo_7b_math.sh
│ │ │ ├── test_dapo_7b_math_lora.sh
│ │ │ ├── test_dapo_7b_math_megatron.sh
│ │ │ ├── test_dapo_dspk_671b_megatron_96gb.sh
│ │ │ ├── test_dapo_glm_air_megatron.sh
│ │ │ ├── test_dapo_qwen3_30b_math.sh
│ │ │ └── test_dapo_qwen3_30b_math_single_node.sh
│ │ ├── deepeyes/
│ │ │ ├── README.md
│ │ │ ├── configs/
│ │ │ │ ├── deepeyes_multiturn_grpo.yaml
│ │ │ │ └── image_zoom_in_tool_config.yaml
│ │ │ ├── deepeyes.py
│ │ │ └── run_deepeyes_grpo.sh
│ │ ├── entropy/
│ │ │ ├── 32b_clip_cov.sh
│ │ │ ├── 32b_kl_cov.sh
│ │ │ ├── 32b_kl_cov_mininbsz.sh
│ │ │ ├── 7b_clip_cov.sh
│ │ │ ├── 7b_kl_cov.sh
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── entropy_trainer.yaml
│ │ │ ├── entropy_ray_trainer.py
│ │ │ ├── main_entropy.py
│ │ │ ├── reward.py
│ │ │ └── reward_score/
│ │ │ ├── __init__.py
│ │ │ └── entropy_math/
│ │ │ ├── __init__.py
│ │ │ ├── grader.py
│ │ │ └── math_normalize.py
│ │ ├── fapo/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── rm_config.yaml
│ │ │ ├── prepare_fapo_data.py
│ │ │ ├── reward_fn_genrm.py
│ │ │ ├── reward_fn_reasoning.py
│ │ │ ├── reward_fn_reasoning_remote.py
│ │ │ ├── run_baseline_32b.sh
│ │ │ ├── run_baseline_7b.sh
│ │ │ ├── run_fapo_32b.sh
│ │ │ ├── run_fapo_32b_remote.sh
│ │ │ ├── run_fapo_7b.sh
│ │ │ ├── run_fapo_7b_remote.sh
│ │ │ ├── run_fapo_genrm_train.sh
│ │ │ └── runtime_env.yaml
│ │ ├── fully_async_policy/
│ │ │ ├── README.md
│ │ │ ├── README_zh.md
│ │ │ ├── agent_loop/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── agent_loop.py
│ │ │ │ └── partial_single_turn_agent_loop.py
│ │ │ ├── config/
│ │ │ │ ├── fully_async_ppo_megatron_trainer.yaml
│ │ │ │ └── fully_async_ppo_trainer.yaml
│ │ │ ├── detach_utils.py
│ │ │ ├── fsdp2_utils.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── fully_async_main.py
│ │ │ ├── fully_async_rollouter.py
│ │ │ ├── fully_async_trainer.py
│ │ │ ├── megatron_worker.py
│ │ │ ├── message_queue.py
│ │ │ ├── param_sync.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── shell/
│ │ │ │ ├── dapo_7b_math_fsdp2_16_16.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_32_32.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_4_4.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_64_64.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_64_64_mis.sh
│ │ │ │ ├── dapo_7b_math_fsdp2_8_8.sh
│ │ │ │ ├── geo3k_qwen25vl_7b_megatron_4_4.sh
│ │ │ │ └── runtime_env.yaml
│ │ │ ├── unittest/
│ │ │ │ └── simple_streaming_demo.py
│ │ │ └── vllm_rollout/
│ │ │ ├── __init__.py
│ │ │ └── vllm_async_server.py
│ │ ├── genrm_remote/
│ │ │ ├── README.md
│ │ │ ├── reward_function.py
│ │ │ └── run_genrm_remote.sh
│ │ ├── gspo/
│ │ │ ├── test_gspo_3b_math.sh
│ │ │ ├── test_gspo_3b_math_slurm.sh
│ │ │ └── test_gspo_qwen30b_a3b_ep.sh
│ │ ├── infigui-g1/
│ │ │ ├── README.md
│ │ │ ├── reward_fn.py
│ │ │ ├── run_3b.sh
│ │ │ └── run_7b.sh
│ │ ├── langgraph_agent/
│ │ │ ├── __init__.py
│ │ │ ├── chat_model.py
│ │ │ ├── example/
│ │ │ │ ├── README.md
│ │ │ │ ├── agent.yaml
│ │ │ │ ├── create_dataset.py
│ │ │ │ ├── math_expression.py
│ │ │ │ ├── run_gpt_oss_20b_bf16.sh
│ │ │ │ └── run_qwen2.5_3b.sh
│ │ │ ├── react_agent_loop.py
│ │ │ └── test_react_agent_loop.py
│ │ ├── minicpmo/
│ │ │ └── rl_dataset.py
│ │ ├── one_step_off_policy/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── one_step_off_ppo_megatron_trainer.yaml
│ │ │ │ └── one_step_off_ppo_trainer.yaml
│ │ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ │ ├── dapo_7b_math_fsdp2_colocate.sh
│ │ │ ├── dapo_7b_math_fsdp2_sglang_4_12.sh
│ │ │ ├── dapo_7b_math_fsdp2_sglang_colocate.sh
│ │ │ ├── dapo_7b_math_megatron_4_12.sh
│ │ │ ├── dapo_7b_math_megatron_colocate.sh
│ │ │ ├── distributed_util.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── grpo_0.6b_gsm8k_fsdp2_2_6.sh
│ │ │ ├── grpo_0.6b_gsm8k_fsdp2_sglang_2_6.sh
│ │ │ ├── grpo_3b_gsm8k_fsdp2_2_6.sh
│ │ │ ├── main_ppo.py
│ │ │ ├── megatron_workers.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── sglang_sharding_manager.py
│ │ │ ├── utils.py
│ │ │ └── vllm_sharding_manager.py
│ │ ├── onpolicy_distill/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── onpolicy_distill_trainer.yaml
│ │ │ ├── main_onpolicy_distill.py
│ │ │ ├── onpolicy_distill_trainer.py
│ │ │ └── run_qwen3_distill.sh
│ │ ├── open_math_reasoning/
│ │ │ ├── README.md
│ │ │ ├── compute_score.py
│ │ │ ├── prepare_eval_dataset.py
│ │ │ ├── prepare_nvidia-OpenMathReasoning_sft.py
│ │ │ ├── run_eval.sh
│ │ │ ├── run_generation.sh
│ │ │ └── run_sft_qwen3_8b.sh
│ │ ├── prime/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── prime_trainer.yaml
│ │ │ ├── main_prime.py
│ │ │ ├── prime_core_algos.py
│ │ │ ├── prime_dp_rm.py
│ │ │ ├── prime_fsdp_workers.py
│ │ │ ├── prime_ray_trainer.py
│ │ │ ├── run_prime_qwen.sh
│ │ │ └── run_prime_qwen_code.sh
│ │ ├── r1/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── evaluation.yaml
│ │ │ ├── data_process.py
│ │ │ ├── main_eval.py
│ │ │ ├── reward_score.py
│ │ │ ├── run_r1_distill_qwen.sh
│ │ │ └── tasks/
│ │ │ ├── __init__.py
│ │ │ ├── gpqa.py
│ │ │ ├── livecodebench.py
│ │ │ └── math_reward.py
│ │ ├── retool/
│ │ │ ├── README.md
│ │ │ ├── retool.py
│ │ │ ├── retool_sft_preprocess.py
│ │ │ ├── run_gpt_oss_ppo.sh
│ │ │ ├── run_qwen2-32b_dapo.sh
│ │ │ ├── run_qwen2-32b_ppo.sh
│ │ │ ├── run_qwen2-32b_sft.sh
│ │ │ ├── run_qwen2_7b_dapo.sh
│ │ │ ├── run_qwen2_7b_sft.sh
│ │ │ ├── run_qwen2_7b_sft_npu.sh
│ │ │ └── sandbox_fusion_tool_config.yaml
│ │ ├── spin/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ └── spin_trainer.yaml
│ │ │ ├── core_algos.py
│ │ │ ├── dp_actor.py
│ │ │ ├── fsdp_workers.py
│ │ │ ├── main_spin.py
│ │ │ ├── run_spin.sh
│ │ │ ├── spin_trainer.py
│ │ │ └── utils.py
│ │ ├── sppo/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ └── sppo_trainer.yaml
│ │ │ ├── config.py
│ │ │ ├── dp_actor.py
│ │ │ ├── main_sppo.py
│ │ │ ├── run_qwen2.5-7b_rm.sh
│ │ │ ├── sppo_ray_trainer.py
│ │ │ └── sppo_worker.py
│ │ └── transfer_queue/
│ │ ├── agent_loop.py
│ │ ├── config/
│ │ │ └── transfer_queue_ppo_trainer.yaml
│ │ ├── main_ppo.py
│ │ ├── ray_trainer.py
│ │ └── run_qwen3-8b_transferqueue_npu.sh
│ ├── requirements-cuda.txt
│ ├── requirements-npu.txt
│ ├── requirements.txt
│ ├── requirements_sglang.txt
│ ├── requirements_transferqueue.txt
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── converter_hf_to_mcore.py
│ │ ├── diagnose.py
│ │ ├── generate_trainer_config.sh
│ │ ├── init_random_model.py
│ │ ├── install_vllm_sglang_mcore.sh
│ │ ├── legacy_model_merger.py
│ │ ├── print_cfg.py
│ │ └── rollout_viewer.py
│ ├── setup.py
│ ├── tests/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── experimental/
│ │ │ ├── agent_loop/
│ │ │ │ ├── agent_utils.py
│ │ │ │ ├── qwen_vl_tool_chat_template.jinja2
│ │ │ │ ├── test_agent_loop_reward.py
│ │ │ │ ├── test_agent_loop_reward_model.py
│ │ │ │ ├── test_basic_agent_loop.py
│ │ │ │ ├── test_gpt_oss_tool_parser.py
│ │ │ │ ├── test_multi_modal.py
│ │ │ │ └── test_standalone_rollout.py
│ │ │ └── reward/
│ │ │ ├── reward_fn.py
│ │ │ ├── test_agent_loop_reward_manager.py
│ │ │ └── test_reward_model.py
│ │ ├── interactions/
│ │ │ ├── __init__.py
│ │ │ ├── test_gsm8k_interaction.py
│ │ │ └── test_interaction_registry.py
│ │ ├── kill_github_tests.sh
│ │ ├── models/
│ │ │ ├── test_engine.py
│ │ │ ├── test_transformer.py
│ │ │ └── test_transformers_ulysses.py
│ │ ├── single_controller/
│ │ │ ├── __init__.py
│ │ │ ├── base/
│ │ │ │ └── test_decorator.py
│ │ │ ├── check_worker_alive/
│ │ │ │ └── main.py
│ │ │ ├── detached_worker/
│ │ │ │ ├── README.md
│ │ │ │ ├── client.py
│ │ │ │ ├── run.sh
│ │ │ │ └── server.py
│ │ │ ├── test_auto_padding_on_cpu.py
│ │ │ ├── test_colocated_workers.py
│ │ │ ├── test_colocated_workers_fused.py
│ │ │ ├── test_data_transfer.py
│ │ │ ├── test_decorator_on_cpu.py
│ │ │ ├── test_device_mesh_register.py
│ │ │ ├── test_driverfunc_to_worker.py
│ │ │ ├── test_fused_workers_on_cpu.py
│ │ │ ├── test_high_level_scheduling_api.py
│ │ │ ├── test_nested_worker.py
│ │ │ ├── test_ray_collectives.py
│ │ │ ├── test_ray_local_envs_on_cpu.py
│ │ │ ├── test_ray_utils_on_cpu.py
│ │ │ ├── test_rvdz.py
│ │ │ ├── test_worker_group_basics.py
│ │ │ └── test_worker_group_torch.py
│ │ ├── special_distributed/
│ │ │ ├── README.md
│ │ │ ├── run_all.sh
│ │ │ ├── test_fsdp_ckpt.py
│ │ │ ├── test_mcore_config_converter.py
│ │ │ └── test_tensor_dict.py
│ │ ├── special_e2e/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── check_custom_rwd_fn.py
│ │ │ ├── check_results.py
│ │ │ ├── envs/
│ │ │ │ ├── __init__.py
│ │ │ │ └── digit_completion/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── task.py
│ │ │ │ └── tokenizer.py
│ │ │ ├── generation/
│ │ │ │ ├── run_gen_qwen05.sh
│ │ │ │ └── run_gen_qwen05_server.sh
│ │ │ ├── ppo_trainer/
│ │ │ │ ├── expert_parallel/
│ │ │ │ │ └── qwen2moe_minimal.json
│ │ │ │ ├── run_function_reward.sh
│ │ │ │ ├── run_model_reward.sh
│ │ │ │ ├── run_single_gpu.sh
│ │ │ │ └── run_single_gpu_with_engine.sh
│ │ │ ├── run_dapo.sh
│ │ │ ├── run_fully_async_policy.sh
│ │ │ ├── run_genrm_remote.sh
│ │ │ ├── run_geo3k_fsdp_sgl_multiturn_w_tool.sh
│ │ │ ├── run_grpo_lora_with_merge.sh
│ │ │ ├── run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh
│ │ │ ├── run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
│ │ │ ├── run_one_step_off_policy.sh
│ │ │ ├── run_ppo_trainer_megatron.sh
│ │ │ ├── run_prime.sh
│ │ │ ├── run_r1_distill_qwen_aime24_eval.sh
│ │ │ ├── run_spin.sh
│ │ │ ├── run_sppo.sh
│ │ │ ├── run_test.sh
│ │ │ └── sft/
│ │ │ ├── compare_sft_engine_results.py
│ │ │ ├── run_sft.sh
│ │ │ ├── run_sft_engine_gsm8k.sh
│ │ │ ├── test_sft_engine_all.sh
│ │ │ └── test_sp_loss_match.py
│ │ ├── special_npu/
│ │ │ ├── run_qwen2_5_05b_dapo.sh
│ │ │ ├── run_qwen2_5_05b_grpo.sh
│ │ │ ├── run_qwen2_5_05b_grpo_mindspeed.sh
│ │ │ ├── run_qwen2_5_05b_sft_peft_sp2.sh
│ │ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ │ └── run_qwen3_06b_ppo.sh
│ │ ├── special_sanity/
│ │ │ ├── check_api_docs.py
│ │ │ ├── check_dataproto_usage.py
│ │ │ ├── check_device_api_usage.py
│ │ │ ├── check_docs_time_info.py
│ │ │ ├── check_docstrings.py
│ │ │ ├── check_license.py
│ │ │ ├── check_pr_description.py
│ │ │ ├── check_pr_title.py
│ │ │ ├── test_config_docs.py
│ │ │ ├── test_import.py
│ │ │ ├── type_coverage_check.py
│ │ │ ├── validate_imported_docs.py
│ │ │ └── validate_structure.py
│ │ ├── special_standalone/
│ │ │ ├── README.md
│ │ │ └── test_memory_buffers.py
│ │ ├── test_base_config_on_cpu.py
│ │ ├── test_protocol_on_cpu.py
│ │ ├── test_protocol_v2_on_cpu.py
│ │ ├── trainer/
│ │ │ ├── __init__.py
│ │ │ ├── config/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── legacy_ppo_megatron_trainer.yaml
│ │ │ │ ├── legacy_ppo_trainer.yaml
│ │ │ │ ├── test_algo_config_on_cpu.py
│ │ │ │ └── test_legacy_config_on_cpu.py
│ │ │ └── ppo/
│ │ │ ├── __init__.py
│ │ │ ├── test_core_algos_on_cpu.py
│ │ │ ├── test_metric_utils_on_cpu.py
│ │ │ ├── test_rollout_is.py
│ │ │ └── test_rollout_is_integration.py
│ │ ├── utils/
│ │ │ ├── _test_module.py
│ │ │ ├── dataset/
│ │ │ │ ├── test_create_rl_sampler_on_cpu.py
│ │ │ │ ├── test_multiturn_sft_dataset_on_cpu.py
│ │ │ │ ├── test_rl_collate_fn_on_cpu.py
│ │ │ │ ├── test_rl_dataset_on_cpu.py
│ │ │ │ └── test_sft_dataset_on_cpu.py
│ │ │ ├── debug/
│ │ │ │ └── test_metrics.py
│ │ │ ├── megatron/
│ │ │ │ └── test_pipeline_parallel.py
│ │ │ ├── reward_score/
│ │ │ │ ├── reward_score/
│ │ │ │ │ └── test_sandbox_fusion_on_cpu.py
│ │ │ │ └── test_sandbox_on_cpu.py
│ │ │ ├── test_activation_offload.py
│ │ │ ├── test_config_on_cpu.py
│ │ │ ├── test_flops_counter.py
│ │ │ ├── test_fs_on_cpu.py
│ │ │ ├── test_groupwise.py
│ │ │ ├── test_import_utils_on_cpu.py
│ │ │ ├── test_linear_cross_entropy.py
│ │ │ ├── test_mlflow_key_sanitization.py
│ │ │ ├── test_model_on_cpu.py
│ │ │ ├── test_nvtx_profile.py
│ │ │ ├── test_rollout_skip_on_cpu.py
│ │ │ ├── test_rollout_trace_on_cpu.py
│ │ │ ├── test_seqlen_balancing.py
│ │ │ ├── test_special_linear_cross_entropy_tp.py
│ │ │ ├── test_special_mstx_profile.py
│ │ │ ├── test_temp_env_on_cpu.py
│ │ │ ├── test_timeout_decorator_cpu.py
│ │ │ └── test_torch_functional.py
│ │ └── workers/
│ │ ├── actor/
│ │ │ └── test_special_dp_actor.py
│ │ ├── config/
│ │ │ ├── test_actor_config_on_cpu.py
│ │ │ ├── test_critic_config_on_cpu.py
│ │ │ ├── test_engine_config_on_cpu.py
│ │ │ └── test_optim_config_on_cpu.py
│ │ ├── critic/
│ │ │ └── test_special_dp_critic.py
│ │ ├── reward_manager/
│ │ │ └── test_registry_on_cpu.py
│ │ ├── rollout/
│ │ │ ├── perf/
│ │ │ │ └── vllm_async_rollout.py
│ │ │ ├── resource/
│ │ │ │ └── tool_configs/
│ │ │ │ ├── mcp_server.json
│ │ │ │ ├── mcp_tool_config
│ │ │ │ ├── sandbox_fusion_tool_config
│ │ │ │ └── search_tool_config
│ │ │ ├── rollout_sglang/
│ │ │ │ └── test_http_server_engine.py
│ │ │ ├── rollout_vllm/
│ │ │ │ ├── run_fsdp_vllm.py
│ │ │ │ ├── test_vllm_model_rope_scaling.py
│ │ │ │ └── test_vllm_spmd.py
│ │ │ ├── test_hf_rollout.py
│ │ │ ├── test_sglang_async_rollout_mcp_tools.py
│ │ │ ├── test_sglang_async_rollout_multimodal_delta.py
│ │ │ ├── test_sglang_async_rollout_search_tools.py
│ │ │ ├── test_sglang_async_rollout_sf_tools.py
│ │ │ ├── test_sglang_async_rollout_w_interaction.py
│ │ │ ├── test_sglang_async_rollout_w_tools.py
│ │ │ ├── test_sglang_async_rollout_w_tools_token_out.py
│ │ │ ├── test_sglang_multi_interaction.py
│ │ │ ├── test_sglang_rollout_sharding_manager.py
│ │ │ ├── test_sglang_spmd.py
│ │ │ └── utils_sglang.py
│ │ ├── test_fsdp_attn_implementation.py
│ │ └── test_fsdp_workers.py
│ └── verl/
│ ├── __init__.py
│ ├── base_config.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ ├── agent_loop/
│ │ │ ├── __init__.py
│ │ │ ├── agent_loop.py
│ │ │ ├── single_turn_agent_loop.py
│ │ │ ├── tool_agent_loop.py
│ │ │ ├── tool_parser.py
│ │ │ └── utils.py
│ │ ├── dataset/
│ │ │ ├── __init__.py
│ │ │ └── sampler.py
│ │ ├── dynamic_dataset/
│ │ │ ├── __init__.py
│ │ │ └── dynamicgen_dataset.py
│ │ └── reward/
│ │ ├── __init__.py
│ │ ├── reward_loop/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── dapo.py
│ │ │ ├── naive.py
│ │ │ └── registry.py
│ │ ├── reward_manager.py
│ │ ├── reward_model.py
│ │ └── router/
│ │ ├── naive_router.py
│ │ └── sglang_router.py
│ ├── interactions/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── gsm8k_interaction.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ └── interaction_registry.py
│ │ └── weather_interaction.py
│ ├── model_merger/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── base_model_merger.py
│ │ ├── fsdp_model_merger.py
│ │ └── megatron_model_merger.py
│ ├── models/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── llama/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── llama_loader.py
│ │ │ │ ├── llama_loader_depracated.py
│ │ │ │ └── llama_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_llama_megatron.py
│ │ ├── mcore/
│ │ │ ├── __init__.py
│ │ │ ├── config_converter.py
│ │ │ ├── loader.py
│ │ │ ├── mbridge.py
│ │ │ ├── model_forward.py
│ │ │ ├── model_forward_1f1b_overlap.py
│ │ │ ├── model_forward_fused.py
│ │ │ ├── model_initializer.py
│ │ │ ├── patch_v012.py
│ │ │ ├── qwen2_5_vl/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── model.py
│ │ │ │ ├── rope_utils.py
│ │ │ │ ├── vision_config.py
│ │ │ │ ├── vision_model.py
│ │ │ │ └── vision_transformer_block.py
│ │ │ ├── readme.md
│ │ │ ├── registry.py
│ │ │ ├── saver.py
│ │ │ ├── util.py
│ │ │ └── weight_converter.py
│ │ ├── qwen2/
│ │ │ ├── __init__.py
│ │ │ └── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── qwen2_loader.py
│ │ │ │ ├── qwen2_loader_depracated.py
│ │ │ │ └── qwen2_saver.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── parallel_attention.py
│ │ │ │ ├── parallel_decoder.py
│ │ │ │ ├── parallel_linear.py
│ │ │ │ ├── parallel_mlp.py
│ │ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_qwen2_megatron.py
│ │ ├── registry.py
│ │ ├── transformers/
│ │ │ ├── __init__.py
│ │ │ ├── apertus.py
│ │ │ ├── dense_common.py
│ │ │ ├── glm4v.py
│ │ │ ├── kimi_vl.py
│ │ │ ├── llama.py
│ │ │ ├── monkey_patch.py
│ │ │ ├── npu_patch.py
│ │ │ ├── qwen2.py
│ │ │ ├── qwen2_vl.py
│ │ │ └── qwen3_vl.py
│ │ └── weight_loader_registry.py
│ ├── protocol.py
│ ├── py.typed
│ ├── single_controller/
│ │ ├── __init__.py
│ │ ├── base/
│ │ │ ├── __init__.py
│ │ │ ├── decorator.py
│ │ │ ├── worker.py
│ │ │ └── worker_group.py
│ │ └── ray/
│ │ ├── __init__.py
│ │ └── base.py
│ ├── third_party/
│ │ ├── __init__.py
│ │ ├── sglang/
│ │ │ ├── __init__.py
│ │ │ └── parallel_state.py
│ │ ├── torch/
│ │ │ ├── __init__.py
│ │ │ └── distributed/
│ │ │ ├── __init__.py
│ │ │ ├── _state_dict_utils.py
│ │ │ └── checkpoint/
│ │ │ ├── __init__.py
│ │ │ └── state_dict.py
│ │ └── vllm/
│ │ └── __init__.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── base_tool.py
│ │ ├── geo3k_tool.py
│ │ ├── gsm8k_tool.py
│ │ ├── image_zoom_in_tool.py
│ │ ├── mcp_base_tool.py
│ │ ├── mcp_search_tool.py
│ │ ├── sandbox_fusion_tools.py
│ │ ├── schemas.py
│ │ ├── search_tool.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── mcp_clients/
│ │ │ ├── McpClientManager.py
│ │ │ └── utils.py
│ │ ├── search_r1_like_utils.py
│ │ └── tool_registry.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── _generated_ppo_megatron_trainer.yaml
│ │ │ ├── _generated_ppo_trainer.yaml
│ │ │ ├── actor/
│ │ │ │ ├── actor.yaml
│ │ │ │ ├── dp_actor.yaml
│ │ │ │ └── megatron_actor.yaml
│ │ │ ├── algorithm.py
│ │ │ ├── config.py
│ │ │ ├── critic/
│ │ │ │ ├── critic.yaml
│ │ │ │ ├── dp_critic.yaml
│ │ │ │ └── megatron_critic.yaml
│ │ │ ├── data/
│ │ │ │ └── legacy_data.yaml
│ │ │ ├── engine/
│ │ │ │ ├── fsdp.yaml
│ │ │ │ └── megatron.yaml
│ │ │ ├── evaluation.yaml
│ │ │ ├── generation.yaml
│ │ │ ├── model/
│ │ │ │ └── hf_model.yaml
│ │ │ ├── npu_profile/
│ │ │ │ └── npu_profile.yaml
│ │ │ ├── optim/
│ │ │ │ ├── fsdp.yaml
│ │ │ │ └── megatron.yaml
│ │ │ ├── ppo_megatron_trainer.yaml
│ │ │ ├── ppo_trainer.yaml
│ │ │ ├── ref/
│ │ │ │ ├── dp_ref.yaml
│ │ │ │ ├── megatron_ref.yaml
│ │ │ │ └── ref.yaml
│ │ │ ├── reward_model/
│ │ │ │ ├── dp_reward_model.yaml
│ │ │ │ ├── megatron_reward_model.yaml
│ │ │ │ └── reward_model.yaml
│ │ │ ├── rollout/
│ │ │ │ └── rollout.yaml
│ │ │ ├── sft_trainer.yaml
│ │ │ └── sft_trainer_engine.yaml
│ │ ├── constants_ppo.py
│ │ ├── fsdp_sft_trainer.py
│ │ ├── main_eval.py
│ │ ├── main_generation.py
│ │ ├── main_generation_server.py
│ │ ├── main_ppo.py
│ │ ├── ppo/
│ │ │ ├── __init__.py
│ │ │ ├── core_algos.py
│ │ │ ├── metric_utils.py
│ │ │ ├── mismatch_helper.py
│ │ │ ├── ray_trainer.py
│ │ │ ├── reward.py
│ │ │ └── utils.py
│ │ ├── runtime_env.yaml
│ │ └── sft_trainer.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── activation_offload.py
│ │ ├── attention_utils.py
│ │ ├── checkpoint/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_handler.py
│ │ │ ├── checkpoint_manager.py
│ │ │ ├── fsdp_checkpoint_manager.py
│ │ │ └── megatron_checkpoint_manager.py
│ │ ├── config.py
│ │ ├── dataset/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── dataset_utils.py
│ │ │ ├── multiturn_sft_dataset.py
│ │ │ ├── onerec_dataset.py
│ │ │ ├── rl_dataset.py
│ │ │ ├── rm_dataset.py
│ │ │ ├── sft_dataset.py
│ │ │ └── vision_utils.py
│ │ ├── debug/
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ ├── performance.py
│ │ │ └── trajectory_tracker.py
│ │ ├── device.py
│ │ ├── distributed.py
│ │ ├── experimental/
│ │ │ ├── __init__.py
│ │ │ └── torch_functional.py
│ │ ├── flops_counter.py
│ │ ├── fs.py
│ │ ├── fsdp_utils.py
│ │ ├── groupwise.py
│ │ ├── hdfs_io.py
│ │ ├── import_utils.py
│ │ ├── kernel/
│ │ │ ├── __init__.py
│ │ │ ├── kernels.py
│ │ │ └── linear_cross_entropy.py
│ │ ├── logger/
│ │ │ ├── __init__.py
│ │ │ └── aggregate_logger.py
│ │ ├── logging_utils.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── dist_checkpointing.py
│ │ │ ├── memory.py
│ │ │ ├── optimizer.py
│ │ │ ├── pipeline_parallel.py
│ │ │ ├── sequence_parallel.py
│ │ │ └── tensor_parallel.py
│ │ ├── megatron_utils.py
│ │ ├── memory_buffer.py
│ │ ├── memory_utils.py
│ │ ├── metric/
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ ├── model.py
│ │ ├── net_utils.py
│ │ ├── npu_utils.py
│ │ ├── profiler/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── empty_annotations.py
│ │ │ ├── mstx_profile.py
│ │ │ ├── nvtx_profile.py
│ │ │ ├── performance.py
│ │ │ └── profile.py
│ │ ├── py_functional.py
│ │ ├── ray_utils.py
│ │ ├── rendezvous/
│ │ │ ├── __init__.py
│ │ │ └── ray_backend.py
│ │ ├── reward_score/
│ │ │ ├── __init__.py
│ │ │ ├── geo3k.py
│ │ │ ├── gsm8k.py
│ │ │ ├── math_batch.py
│ │ │ ├── math_dapo.py
│ │ │ ├── math_reward.py
│ │ │ ├── math_verify.py
│ │ │ ├── prime_code/
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── testing_util.py
│ │ │ │ └── utils.py
│ │ │ ├── prime_math/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── grader.py
│ │ │ │ └── math_normalize.py
│ │ │ ├── sandbox_fusion/
│ │ │ │ ├── __init__.py
│ │ │ │ └── utils.py
│ │ │ └── search_r1_like_qa_em.py
│ │ ├── rollout_skip.py
│ │ ├── rollout_trace.py
│ │ ├── seqlen_balancing.py
│ │ ├── tensordict_utils.py
│ │ ├── tokenizer.py
│ │ ├── torch_dtypes.py
│ │ ├── torch_functional.py
│ │ ├── tracking.py
│ │ ├── transferqueue_utils.py
│ │ ├── transformers_compat.py
│ │ ├── ulysses.py
│ │ └── vllm/
│ │ ├── __init__.py
│ │ ├── patch.py
│ │ └── utils.py
│ ├── version/
│ │ └── version
│ └── workers/
│ ├── __init__.py
│ ├── actor/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── dp_actor.py
│ │ └── megatron_actor.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── actor.py
│ │ ├── critic.py
│ │ ├── engine.py
│ │ ├── model.py
│ │ ├── optimizer.py
│ │ ├── reward_model.py
│ │ └── rollout.py
│ ├── critic/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── dp_critic.py
│ │ └── megatron_critic.py
│ ├── engine/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── fsdp/
│ │ │ ├── __init__.py
│ │ │ ├── transformer_impl.py
│ │ │ └── utils.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── transformer_impl.py
│ │ │ └── utils.py
│ │ ├── mindspeed/
│ │ │ ├── __init__.py
│ │ │ └── transformer_impl.py
│ │ └── utils.py
│ ├── fsdp_workers.py
│ ├── megatron_workers.py
│ ├── reward_manager/
│ │ ├── __init__.py
│ │ ├── abstract.py
│ │ ├── batch.py
│ │ ├── dapo.py
│ │ ├── naive.py
│ │ ├── prime.py
│ │ └── registry.py
│ ├── reward_model/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ └── reward_model.py
│ ├── roles/
│ │ ├── __init__.py
│ │ ├── actor.py
│ │ ├── critic.py
│ │ ├── hybrid_engine.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── losses.py
│ │ └── padding.py
│ ├── rollout/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── hf_rollout.py
│ │ ├── naive/
│ │ │ ├── __init__.py
│ │ │ └── naive_rollout.py
│ │ ├── replica.py
│ │ ├── schemas.py
│ │ ├── sglang_rollout/
│ │ │ ├── __init__.py
│ │ │ ├── async_sglang_server.py
│ │ │ ├── http_server_engine.py
│ │ │ ├── sglang_rollout.py
│ │ │ └── utils.py
│ │ ├── tokenizer.py
│ │ ├── utils.py
│ │ └── vllm_rollout/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── vllm_async_server.py
│ │ └── vllm_rollout_spmd.py
│ └── sharding_manager/
│ ├── __init__.py
│ ├── base.py
│ ├── fsdp_sglang.py
│ ├── fsdp_ulysses.py
│ ├── fsdp_vllm.py
│ ├── megatron_sglang.py
│ └── megatron_vllm.py
└── verl_rl/
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── README_ORIGINAL.md
├── deploy_env.sh
├── docker/
│ ├── Apptainerfile.rocm
│ ├── Dockerfile.extention.awsefa
│ ├── Dockerfile.ngc.vllm
│ ├── Dockerfile.ngc.vllm0.8
│ ├── Dockerfile.ngc.vllm0.8.sagemaker
│ ├── Dockerfile.rocm
│ ├── Dockerfile.rocm_verl-0.3.0.post1
│ ├── Dockerfile.rocm_verl-0.4.1
│ ├── Dockerfile.sglang
│ ├── Dockerfile.vemlp.vllm.te
│ ├── Dockerfile.vllm.sglang.megatron.deepseek
│ ├── README.md
│ ├── verl0.4-cu124-torch2.6-fa2.7.4/
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.12
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.12.deepep
│ │ ├── Dockerfile.app.sglang.vllm.mcore0.13.preview
│ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ ├── Dockerfile.app.vllm.mcore0.12.deepep
│ │ ├── Dockerfile.app.vllm.mcore0.13.preview
│ │ ├── Dockerfile.base
│ │ └── README.md
│ ├── verl0.5-cu126-torch2.7-fa2.7.4/
│ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ ├── Dockerfile.app.vllm.mcore0.12
│ │ ├── Dockerfile.base.torch2.7.0
│ │ ├── Dockerfile.base.torch2.7.1
│ │ └── README.md
│ ├── verl0.5-cu126-torch2.7.1-fa2.8.0/
│ │ ├── Dockerfile.app.sglang.mcore0.12
│ │ ├── Dockerfile.app.sglang.mcore0.13.preview
│ │ ├── Dockerfile.base
│ │ └── README.md
│ └── verl0.5-preview-cu128-torch2.7.1-fa2.8.0/
│ ├── Dockerfile.app.sglang.megatron
│ ├── Dockerfile.base
│ └── README.md
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── README_vllm0.7.md
│ ├── README_vllm0.8.md
│ ├── _static/
│ │ └── js/
│ │ └── runllm-widget.js
│ ├── advance/
│ │ ├── agent_loop.rst
│ │ ├── checkpoint.rst
│ │ ├── dpo_extension.rst
│ │ ├── fsdp_extension.rst
│ │ ├── megatron_extension.rst
│ │ ├── one_step_off.md
│ │ ├── placement.rst
│ │ ├── ppo_lora.rst
│ │ ├── rollout_trace.rst
│ │ └── rope.rst
│ ├── algo/
│ │ ├── baseline.md
│ │ ├── dapo.md
│ │ ├── entropy.md
│ │ ├── gpg.md
│ │ ├── grpo.md
│ │ ├── opo.md
│ │ ├── ppo.md
│ │ ├── spin.md
│ │ └── sppo.md
│ ├── amd_tutorial/
│ │ ├── amd_build_dockerfile_page.rst
│ │ └── amd_vllm_page.rst
│ ├── api/
│ │ ├── data.rst
│ │ ├── single_controller.rst
│ │ ├── trainer.rst
│ │ └── utils.rst
│ ├── ascend_tutorial/
│ │ ├── ascend_profiling.rst
│ │ ├── ascend_profiling_en.rst
│ │ └── ascend_quick_start.rst
│ ├── conf.py
│ ├── examples/
│ │ ├── config.rst
│ │ ├── gsm8k_example.rst
│ │ ├── multi_modal_example.rst
│ │ ├── ppo_code_architecture.rst
│ │ └── sandbox_fusion_example.rst
│ ├── faq/
│ │ └── faq.rst
│ ├── hybrid_flow.rst
│ ├── index.rst
│ ├── perf/
│ │ ├── device_tuning.rst
│ │ ├── dpsk.md
│ │ ├── nsight_profiling.md
│ │ └── perf_tuning.rst
│ ├── preparation/
│ │ ├── prepare_data.rst
│ │ └── reward_function.rst
│ ├── requirements-docs.txt
│ ├── sglang_multiturn/
│ │ ├── interaction_system.rst
│ │ ├── multiturn.rst
│ │ ├── sandbox_fusion.rst
│ │ └── search_tool_example.rst
│ ├── single_controller.rst
│ ├── start/
│ │ ├── agentic_rl.rst
│ │ ├── install.rst
│ │ ├── more_resources.rst
│ │ ├── multinode.rst
│ │ ├── quickstart.rst
│ │ └── ray_debug_tutorial.rst
│ └── workers/
│ ├── fsdp_workers.rst
│ ├── megatron_workers.rst
│ ├── ray_trainer.rst
│ └── sglang_worker.rst
├── examples/
│ ├── data_preprocess/
│ │ ├── aime2024_multiturn_w_tool.py
│ │ ├── dapo_multiturn_w_tool.py
│ │ ├── full_hh_rlhf.py
│ │ ├── geo3k.py
│ │ ├── geo3k_multiturn_w_tool.py
│ │ ├── gsm8k.py
│ │ ├── gsm8k_multiturn_w_interaction.py
│ │ ├── gsm8k_multiturn_w_tool.py
│ │ ├── gsm8k_tool_agent_loop.py
│ │ ├── hellaswag.py
│ │ ├── math_dataset.py
│ │ ├── multiturn.py
│ │ └── preprocess_search_r1_dataset.py
│ ├── generation/
│ │ ├── run_deepseek7b_mutli_node.sh
│ │ └── run_deepseek_v2_lite_math.sh
│ ├── gpg_trainer/
│ │ ├── gpg.md
│ │ ├── run_qwen2-7b_math.sh
│ │ └── run_qwen2-7b_math_megatron.sh
│ ├── grpo_trainer/
│ │ ├── README.md
│ │ ├── run_deepseek671b_math_megatron.sh
│ │ ├── run_deepseek7b_llm.sh
│ │ ├── run_deepseek7b_llm_math.sh
│ │ ├── run_deepseek7b_llm_math_megatron.sh
│ │ ├── run_deepseek7b_llm_seq_balance.sh
│ │ ├── run_minicpmo2_6.sh
│ │ ├── run_moonlight16b_math_megatron.sh
│ │ ├── run_qwen2-7b.sh
│ │ ├── run_qwen2-7b_math.sh
│ │ ├── run_qwen2-7b_math_megatron.sh
│ │ ├── run_qwen2-7b_seq_balance.sh
│ │ ├── run_qwen2-7b_seq_balance_math_megatron.sh
│ │ ├── run_qwen2-7b_sgl_megatron.sh
│ │ ├── run_qwen2_5-3b_gsm8k_grpo_lora.sh
│ │ ├── run_qwen2_5-7b_math_megatron_diff_tp.sh
│ │ ├── run_qwen2_5_32b_grpo_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_discrete_prof_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_e2e_prof_npu.sh
│ │ ├── run_qwen2_5_7b_grpo_npu.sh
│ │ ├── run_qwen2_5_vl-7b-megatron.sh
│ │ ├── run_qwen2_5_vl-7b.sh
│ │ ├── run_qwen2_5_vl-7b_lora.sh
│ │ ├── run_qwen2_5_vl-7b_seq_balance.sh
│ │ ├── run_qwen2_5_vl_32b_npu.sh
│ │ ├── run_qwen2_5_vl_3b_npu.sh
│ │ ├── run_qwen2_5_vl_7b_npu.sh
│ │ ├── run_qwen3-236b_megatron.sh
│ │ ├── run_qwen3-8b.sh
│ │ └── run_qwen3moe-30b_megatron.sh
│ ├── ppo_trainer/
│ │ ├── README.md
│ │ ├── run_deepseek7b_llm.sh
│ │ ├── run_deepseek7b_llm_modelscope.sh
│ │ ├── run_deepseek7b_llm_pfppo.sh
│ │ ├── run_deepseek7b_llm_sandbox_fusion.sh
│ │ ├── run_deepseek7b_llm_sp2.sh
│ │ ├── run_deepseek_full_hh_rlhf.sh
│ │ ├── run_deepseek_math_gsm8k_megatron.sh
│ │ ├── run_deepseek_math_gsm8k_megatron_nsys.sh
│ │ ├── run_gemma.sh
│ │ ├── run_moonlight16b_a3b_gsm8k_megatron.sh
│ │ ├── run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh
│ │ ├── run_qwen2-7b_math_gsm8k_megatron.sh
│ │ ├── run_qwen2-7b_rm.sh
│ │ ├── run_qwen2-7b_rm_seq_balance.sh
│ │ ├── run_qwen2-7b_rm_seq_balance_fused_kernels.sh
│ │ ├── run_qwen2-7b_rm_seq_balance_nsys.sh
│ │ ├── run_qwen2-7b_seq_balance.sh
│ │ ├── run_qwen2-7b_sglang_seq_balance.sh
│ │ └── run_qwen2.5-32b.sh
│ ├── ray/
│ │ └── tutorial.ipynb
│ ├── reinforce_plus_plus_trainer/
│ │ ├── run_qwen2-7b_math_rf.sh
│ │ └── run_qwen2-7b_math_rf_baseline.sh
│ ├── remax_trainer/
│ │ ├── run_qwen2.5-3b_seq_balance.sh
│ │ └── run_qwen2.5-7b_seq_balance.sh
│ ├── rloo_trainer/
│ │ └── run_qwen2-7b.sh
│ ├── sft/
│ │ ├── gsm8k/
│ │ │ ├── run_deepseek_6b7.sh
│ │ │ ├── run_gemma_2b.sh
│ │ │ ├── run_gemma_7b.sh
│ │ │ ├── run_qwen2_5_05b_sft_peft_sp2_npu.sh
│ │ │ ├── run_qwen_05_peft.sh
│ │ │ ├── run_qwen_05_sp2.sh
│ │ │ └── run_qwen_05_sp2_liger.sh
│ │ └── multiturn/
│ │ └── run_qwen_05_sp2.sh
│ ├── sglang_multiturn/
│ │ ├── README.md
│ │ ├── config/
│ │ │ ├── geo3k_multiturn_grpo.yaml
│ │ │ ├── geo3k_multiturn_megatron_grpo.yaml
│ │ │ ├── gsm8k_multiturn_grpo.yaml
│ │ │ ├── gsm8k_multiturn_grpo_w_interaction.yaml
│ │ │ ├── gsm8k_multiturn_megatron_grpo.yaml
│ │ │ ├── interaction_config/
│ │ │ │ └── gsm8k_interaction_config.yaml
│ │ │ ├── retool_multiturn_grpo.yaml
│ │ │ ├── search_multiturn_grpo.yaml
│ │ │ └── tool_config/
│ │ │ ├── geo3k_tool_config.yaml
│ │ │ ├── gsm8k_tool_config.yaml
│ │ │ ├── mcp_server.json
│ │ │ ├── mcp_tool_config.yaml
│ │ │ ├── sandbox_fusion_tool_config.yaml
│ │ │ └── search_tool_config.yaml
│ │ ├── geo3k/
│ │ │ ├── run_qwen2.5-3b_geo3k_multiturn.sh
│ │ │ ├── run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh
│ │ │ └── run_qwen2.5-3b_megatron_geo3k_multiturn.sh
│ │ ├── run_qwen0.5b_gsm8k_multiturn_curriculum.sh
│ │ ├── run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh
│ │ ├── run_qwen2.5-3b_gsm8k_multiturn.sh
│ │ ├── run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
│ │ ├── run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh
│ │ ├── run_qwen2.5-3b_megatron_gsm8k_multiturn.sh
│ │ ├── run_qwen3-4b_gsm8k_multiturn.sh
│ │ └── search_r1_like/
│ │ ├── local_dense_retriever/
│ │ │ ├── download.py
│ │ │ └── retrieval_server.py
│ │ └── run_qwen2.5-3b_instruct_search_multiturn.sh
│ ├── slurm/
│ │ └── ray_on_slurm.slurm
│ ├── split_placement/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── ppo_trainer_split.yaml
│ │ ├── main_ppo_split.py
│ │ ├── run_deepseek7b_llm.sh
│ │ └── split_monkey_patch.py
│ └── tuning/
│ ├── 0.5b/
│ │ └── qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 1.5b/
│ │ └── qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 14b/
│ │ ├── qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh
│ │ └── qwen2_14b_grpo_4_h800_fsdp_vllm.sh
│ ├── 32b/
│ │ ├── qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh
│ │ └── qwen2_32B_grpo_8_h20_megatron_vllm.sh
│ ├── 3b/
│ │ └── qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh
│ ├── 70b/
│ │ ├── qwen2-70b_grpo_32_h20_fsdp_vllm.sh
│ │ ├── qwen2-70b_grpo_32_h800_fsdp_vllm.sh
│ │ └── qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh
│ └── 7b/
│ ├── qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh
│ └── qwen2-7b_grpo_2_h800_fsdp_vllm.sh
├── init_ray.sh
├── init_ray_cluster.sh
├── pyproject.toml
├── recipe/
│ ├── README.md
│ ├── char_count/
│ │ ├── README.md
│ │ ├── create_dataset.py
│ │ ├── reward_function.py
│ │ ├── train_grpo.sh
│ │ └── train_sft.sh
│ ├── dapo/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── dapo_trainer.yaml
│ │ ├── dapo_ray_trainer.py
│ │ ├── main_dapo.py
│ │ ├── prepare_dapo_data.sh
│ │ ├── run_dapo_early_qwen2.5_32b.sh
│ │ ├── run_dapo_qwen2.5_32b.sh
│ │ ├── run_dapo_wo_ds_qwen2.5_32b.sh
│ │ ├── runtime_env.yaml
│ │ ├── test_dapo_7b.sh
│ │ ├── test_dapo_7b_math.sh
│ │ ├── test_dapo_7b_math_lora.sh
│ │ ├── test_dapo_7b_math_megatron.sh
│ │ ├── test_dapo_dspk_671b_megatron.sh
│ │ ├── test_dapo_qwen3_30b_math.sh
│ │ └── test_dapo_qwen3_30b_math_single_node.sh
│ ├── entropy/
│ │ ├── 32b_clip_cov.sh
│ │ ├── 32b_kl_cov.sh
│ │ ├── 32b_kl_cov_mininbsz.sh
│ │ ├── 7b_clip_cov.sh
│ │ ├── 7b_kl_cov.sh
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── entropy_trainer.yaml
│ │ ├── entropy_ray_trainer.py
│ │ ├── main_entropy.py
│ │ ├── reward.py
│ │ └── reward_score/
│ │ ├── __init__.py
│ │ └── entropy_math/
│ │ ├── __init__.py
│ │ ├── grader.py
│ │ └── math_normalize.py
│ ├── genrm_remote/
│ │ ├── README.md
│ │ ├── reward_function.py
│ │ └── run_genrm_remote.sh
│ ├── langgraph_agent/
│ │ ├── __init__.py
│ │ ├── chat_model.py
│ │ ├── example/
│ │ │ ├── README.md
│ │ │ ├── agent.yaml
│ │ │ ├── create_dataset.py
│ │ │ ├── math_expression.py
│ │ │ └── run_qwen2.5_3b.sh
│ │ ├── react_agent_loop.py
│ │ └── test_react_agent_loop.py
│ ├── minicpmo/
│ │ └── rl_dataset.py
│ ├── one_step_off_policy/
│ │ ├── README.md
│ │ ├── config/
│ │ │ ├── one_step_off_ppo_megatron_trainer.yaml
│ │ │ └── one_step_off_ppo_trainer.yaml
│ │ ├── dapo_7b_math_fsdp2_4_12.sh
│ │ ├── dapo_7b_math_fsdp2_colocate.sh
│ │ ├── dapo_7b_math_megatron_4_12.sh
│ │ ├── dapo_7b_math_megatron_colocate.sh
│ │ ├── fsdp_workers.py
│ │ ├── grpo_0.6b_gsm8k_fsdp2_2_6.sh
│ │ ├── grpo_3b_gsm8k_fsdp2_2_6.sh
│ │ ├── main_ppo.py
│ │ ├── megatron_workers.py
│ │ ├── ray_trainer.py
│ │ └── vllm_sharding_manager.py
│ ├── onerec/
│ │ ├── main_onerec_ppo.py
│ │ ├── onerec_fsdp_workers.py
│ │ ├── onerec_ray_trainer.py
│ │ ├── onerec_recipe.py
│ │ ├── onerec_vllm_rollout.py
│ │ └── run_grpo.sh
│ ├── prime/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ └── prime_trainer.yaml
│ │ ├── main_prime.py
│ │ ├── prime_core_algos.py
│ │ ├── prime_dp_rm.py
│ │ ├── prime_fsdp_workers.py
│ │ ├── prime_ray_trainer.py
│ │ ├── run_prime_qwen.sh
│ │ └── run_prime_qwen_code.sh
│ ├── r1/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ └── evaluation.yaml
│ │ ├── data_process.py
│ │ ├── main_eval.py
│ │ ├── reward_score.py
│ │ ├── run_r1_distill_qwen.sh
│ │ └── tasks/
│ │ ├── __init__.py
│ │ ├── gpqa.py
│ │ ├── livecodebench.py
│ │ └── math.py
│ ├── retool/
│ │ ├── retool.py
│ │ ├── retool_multi_turn_sft_preprocess.py
│ │ ├── retool_sft_preprocess.py
│ │ ├── run_qwen2-32b_sft.sh
│ │ ├── run_qwen2.5_32b_sp8.sh
│ │ ├── run_qwen2.5_7b_sp4.sh
│ │ ├── run_qwen3_4b_sp4.sh
│ │ └── sandbox_fusion_tool_config.yaml
│ ├── spin/
│ │ ├── README.md
│ │ ├── config/
│ │ │ └── spin_trainer.yaml
│ │ ├── core_algos.py
│ │ ├── dp_actor.py
│ │ ├── fsdp_workers.py
│ │ ├── main_spin.py
│ │ ├── run_spin.sh
│ │ └── spin_trainer.py
│ └── sppo/
│ ├── README.md
│ ├── __init__.py
│ ├── config/
│ │ └── sppo_trainer.yaml
│ ├── dp_actor.py
│ ├── main_sppo.py
│ ├── run_qwen2.5-7b_rm.sh
│ ├── sppo_ray_trainer.py
│ └── sppo_worker.py
├── requirements-npu.txt
├── requirements.txt
├── requirements_sglang.txt
├── scripts/
│ ├── __init__.py
│ ├── converter_hf_to_mcore.py
│ ├── diagnose.py
│ ├── generate_trainer_config.sh
│ ├── init_random_model.py
│ ├── install_vllm_sglang_mcore.sh
│ ├── legacy_model_merger.py
│ ├── print_cfg.py
│ └── rollout_viewer.py
├── setup.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── experimental/
│ │ └── agent_loop/
│ │ ├── agent_utils.py
│ │ └── test_basic_agent_loop.py
│ ├── interactions/
│ │ ├── __init__.py
│ │ ├── test_gsm8k_interaction.py
│ │ └── test_interaction_registry.py
│ ├── kill_github_tests.sh
│ ├── models/
│ │ ├── test_transformer.py
│ │ └── test_transformers_ulysses.py
│ ├── single_controller/
│ │ ├── __init__.py
│ │ ├── base/
│ │ │ └── test_decorator.py
│ │ ├── check_worker_alive/
│ │ │ └── main.py
│ │ ├── detached_worker/
│ │ │ ├── README.md
│ │ │ ├── client.py
│ │ │ ├── run.sh
│ │ │ └── server.py
│ │ ├── test_auto_padding_on_cpu.py
│ │ ├── test_colocated_workers.py
│ │ ├── test_colocated_workers_fused.py
│ │ ├── test_data_transfer.py
│ │ ├── test_decorator_on_cpu.py
│ │ ├── test_driverfunc_to_worker.py
│ │ ├── test_fused_workers_on_cpu.py
│ │ ├── test_high_level_scheduling_api.py
│ │ ├── test_ray_collectives.py
│ │ ├── test_ray_local_envs_on_cpu.py
│ │ ├── test_ray_utils_on_cpu.py
│ │ ├── test_rvdz.py
│ │ ├── test_worker_group_basics.py
│ │ └── test_worker_group_torch.py
│ ├── special_distributed/
│ │ ├── README.md
│ │ ├── run_all.sh
│ │ ├── test_fsdp_ckpt.py
│ │ └── test_tensor_dict.py
│ ├── special_e2e/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── check_custom_rwd_fn.py
│ │ ├── check_results.py
│ │ ├── envs/
│ │ │ ├── __init__.py
│ │ │ └── digit_completion/
│ │ │ ├── __init__.py
│ │ │ ├── task.py
│ │ │ └── tokenizer.py
│ │ ├── generation/
│ │ │ └── run_gen_qwen05.sh
│ │ ├── ppo_trainer/
│ │ │ ├── expert_parallel/
│ │ │ │ └── qwen2moe_minimal.json
│ │ │ ├── run_function_reward.sh
│ │ │ ├── run_model_reward.sh
│ │ │ ├── run_single_gpu.sh
│ │ │ └── run_single_gpu_with_engine.sh
│ │ ├── run_dapo.sh
│ │ ├── run_genrm_remote.sh
│ │ ├── run_geo3k_fsdp_sgl_multiturn_w_tool.sh
│ │ ├── run_grpo_lora_with_merge.sh
│ │ ├── run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh
│ │ ├── run_gsm8k_fsdp_sgl_multiturn_w_tool.sh
│ │ ├── run_one_step_off_policy.sh
│ │ ├── run_ppo_trainer_megatron.sh
│ │ ├── run_prime.sh
│ │ ├── run_r1_distill_qwen_aime24_eval.sh
│ │ ├── run_spin.sh
│ │ ├── run_sppo.sh
│ │ ├── run_test.sh
│ │ └── sft/
│ │ ├── run_sft.sh
│ │ └── test_sp_loss_match.py
│ ├── special_npu/
│ │ ├── run_qwen2_5_05b_dapo.sh
│ │ ├── run_qwen2_5_05b_grpo.sh
│ │ ├── run_qwen2_5_05b_sft_peft_sp2.sh
│ │ └── run_qwen2_5_vl_3b_npu.sh
│ ├── special_sanity/
│ │ ├── check_api_docs.py
│ │ ├── check_device_api_usage.py
│ │ ├── check_docs_time_info.py
│ │ ├── check_docstrings.py
│ │ ├── check_license.py
│ │ ├── check_pr_description.py
│ │ ├── check_pr_title.py
│ │ ├── test_config_docs.py
│ │ ├── test_import.py
│ │ ├── type_coverage_check.py
│ │ ├── validate_imported_docs.py
│ │ └── validate_structure.py
│ ├── special_standalone/
│ │ ├── README.md
│ │ └── test_memory_buffers.py
│ ├── test_base_config_on_cpu.py
│ ├── test_protocol_on_cpu.py
│ ├── tools/
│ │ └── test_base_tool_on_cpu.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── legacy_ppo_megatron_trainer.yaml
│ │ │ ├── legacy_ppo_trainer.yaml
│ │ │ ├── test_algo_config_on_cpu.py
│ │ │ ├── test_critic_config_on_cpu.py
│ │ │ └── test_legacy_config_on_cpu.py
│ │ └── ppo/
│ │ ├── __init__.py
│ │ ├── test_core_algos_on_cpu.py
│ │ └── test_metric_utils_on_cpu.py
│ ├── utils/
│ │ ├── _test_module.py
│ │ ├── dataset/
│ │ │ ├── test_create_rl_sampler_on_cpu.py
│ │ │ ├── test_multiturn_sft_dataset_on_cpu.py
│ │ │ ├── test_rl_dataset_on_cpu.py
│ │ │ └── test_sft_dataset_on_cpu.py
│ │ ├── megatron/
│ │ │ └── test_pipeline_parallel.py
│ │ ├── reward_score/
│ │ │ ├── reward_score/
│ │ │ │ └── test_sandbox_fusion_on_cpu.py
│ │ │ └── test_sandbox_on_cpu.py
│ │ ├── test_activation_offload.py
│ │ ├── test_config_on_cpu.py
│ │ ├── test_flops_counter.py
│ │ ├── test_fs_on_cpu.py
│ │ ├── test_import_utils_on_cpu.py
│ │ ├── test_linear_cross_entropy.py
│ │ ├── test_linear_cross_entropy_tp.py
│ │ ├── test_model_on_cpu.py
│ │ ├── test_nvtx_profile.py
│ │ ├── test_rollout_trace_on_cpu.py
│ │ ├── test_seqlen_balancing.py
│ │ ├── test_temp_env_on_cpu.py
│ │ ├── test_timeout_decorator_cpu.py
│ │ └── test_torch_functional.py
│ └── workers/
│ ├── reward_manager/
│ │ └── test_registry_on_cpu.py
│ └── rollout/
│ ├── async_rollout_utils.py
│ ├── perf/
│ │ └── vllm_async_rollout.py
│ ├── resource/
│ │ └── tool_configs/
│ │ ├── mcp_server.json
│ │ ├── mcp_tool_config
│ │ ├── sandbox_fusion_tool_config
│ │ └── search_tool_config
│ ├── rollout_vllm/
│ │ ├── run_fsdp_vllm.py
│ │ ├── test_vllm_chat_scheduler.py
│ │ ├── test_vllm_model_rope_scaling.py
│ │ └── test_vllm_spmd.py
│ ├── test_async_sglang_server_on_cpu.py
│ ├── test_custom_completion_callback.py
│ ├── test_hf_rollout.py
│ ├── test_sglang_async_rollout_mcp_tools.py
│ ├── test_sglang_async_rollout_multimodal_delta.py
│ ├── test_sglang_async_rollout_search_tools.py
│ ├── test_sglang_async_rollout_sf_tools.py
│ ├── test_sglang_async_rollout_w_interaction.py
│ ├── test_sglang_async_rollout_w_tools.py
│ ├── test_sglang_multi_interaction.py
│ ├── test_sglang_rollout_sharding_manager.py
│ ├── test_sglang_spmd.py
│ └── utils_sglang.py
└── verl/
├── __init__.py
├── base_config.py
├── experimental/
│ ├── __init__.py
│ ├── agent_loop/
│ │ ├── __init__.py
│ │ ├── agent_loop.py
│ │ ├── single_turn_agent_loop.py
│ │ ├── tool_agent_loop.py
│ │ └── tool_parser.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ └── sampler.py
│ └── dynamic_dataset/
│ ├── __init__.py
│ └── dynamicgen_dataset.py
├── interactions/
│ ├── __init__.py
│ ├── base.py
│ ├── gsm8k_interaction.py
│ └── utils/
│ ├── __init__.py
│ └── interaction_registry.py
├── model_merger/
│ ├── __init__.py
│ ├── __main__.py
│ ├── base_model_merger.py
│ ├── fsdp_model_merger.py
│ └── megatron_model_merger.py
├── models/
│ ├── README.md
│ ├── __init__.py
│ ├── llama/
│ │ ├── __init__.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ ├── checkpoint_utils/
│ │ │ ├── __init__.py
│ │ │ ├── llama_loader.py
│ │ │ ├── llama_loader_depracated.py
│ │ │ └── llama_saver.py
│ │ ├── layers/
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ └── modeling_llama_megatron.py
│ ├── mcore/
│ │ ├── __init__.py
│ │ ├── config_converter.py
│ │ ├── loader.py
│ │ ├── mbridge.py
│ │ ├── model_forward.py
│ │ ├── model_forward_fused.py
│ │ ├── model_initializer.py
│ │ ├── patch_v012.py
│ │ ├── qwen2_5_vl/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── model.py
│ │ │ ├── rope_utils.py
│ │ │ ├── vision_config.py
│ │ │ ├── vision_model.py
│ │ │ └── vision_transformer_block.py
│ │ ├── readme.md
│ │ ├── registry.py
│ │ ├── saver.py
│ │ ├── util.py
│ │ └── weight_converter.py
│ ├── qwen2/
│ │ ├── __init__.py
│ │ └── megatron/
│ │ ├── __init__.py
│ │ ├── checkpoint_utils/
│ │ │ ├── __init__.py
│ │ │ ├── qwen2_loader.py
│ │ │ ├── qwen2_loader_depracated.py
│ │ │ └── qwen2_saver.py
│ │ ├── layers/
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ └── modeling_qwen2_megatron.py
│ ├── registry.py
│ ├── transformers/
│ │ ├── __init__.py
│ │ ├── dense_common.py
│ │ ├── kimi_vl.py
│ │ ├── llama.py
│ │ ├── monkey_patch.py
│ │ ├── npu_patch.py
│ │ ├── qwen2.py
│ │ ├── qwen2_5_vl.py
│ │ └── qwen2_vl.py
│ └── weight_loader_registry.py
├── protocol.py
├── py.typed
├── single_controller/
│ ├── __init__.py
│ ├── base/
│ │ ├── __init__.py
│ │ ├── decorator.py
│ │ ├── megatron/
│ │ │ ├── __init__.py
│ │ │ ├── worker.py
│ │ │ └── worker_group.py
│ │ ├── register_center/
│ │ │ ├── __init__.py
│ │ │ └── ray.py
│ │ ├── worker.py
│ │ └── worker_group.py
│ └── ray/
│ ├── __init__.py
│ ├── base.py
│ └── megatron.py
├── third_party/
│ ├── __init__.py
│ ├── sglang/
│ │ ├── __init__.py
│ │ └── parallel_state.py
│ ├── torch/
│ │ ├── __init__.py
│ │ └── distributed/
│ │ ├── __init__.py
│ │ ├── _state_dict_utils.py
│ │ └── checkpoint/
│ │ ├── __init__.py
│ │ └── state_dict.py
│ └── vllm/
│ └── __init__.py
├── tools/
│ ├── __init__.py
│ ├── base_tool.py
│ ├── geo3k_tool.py
│ ├── gsm8k_tool.py
│ ├── mcp_base_tool.py
│ ├── mcp_search_tool.py
│ ├── sandbox_fusion_tools.py
│ ├── schemas.py
│ ├── search_tool.py
│ └── utils/
│ ├── __init__.py
│ ├── mcp_clients/
│ │ ├── McpClientManager.py
│ │ └── utils.py
│ ├── search_r1_like_utils.py
│ └── tool_registry.py
├── trainer/
│ ├── __init__.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── _generated_ppo_megatron_trainer.yaml
│ │ ├── _generated_ppo_trainer.yaml
│ │ ├── actor/
│ │ │ ├── actor.yaml
│ │ │ ├── dp_actor.yaml
│ │ │ └── megatron_actor.yaml
│ │ ├── algorithm.py
│ │ ├── config.py
│ │ ├── critic/
│ │ │ ├── critic.yaml
│ │ │ ├── dp_critic.yaml
│ │ │ └── megatron_critic.yaml
│ │ ├── data/
│ │ │ └── legacy_data.yaml
│ │ ├── evaluation.yaml
│ │ ├── generation.yaml
│ │ ├── npu_profile/
│ │ │ └── npu_profile.yaml
│ │ ├── ppo_megatron_trainer.yaml
│ │ ├── ppo_trainer.yaml
│ │ ├── ref/
│ │ │ ├── dp_ref.yaml
│ │ │ ├── megatron_ref.yaml
│ │ │ └── ref.yaml
│ │ ├── reward_model/
│ │ │ ├── dp_reward_model.yaml
│ │ │ ├── megatron_reward_model.yaml
│ │ │ └── reward_model.yaml
│ │ ├── rollout/
│ │ │ └── rollout.yaml
│ │ └── sft_trainer.yaml
│ ├── constants_ppo.py
│ ├── fsdp_sft_trainer.py
│ ├── main_eval.py
│ ├── main_generation.py
│ ├── main_ppo.py
│ ├── ppo/
│ │ ├── __init__.py
│ │ ├── core_algos.py
│ │ ├── metric_utils.py
│ │ ├── ray_trainer.py
│ │ └── reward.py
│ └── runtime_env.yaml
├── utils/
│ ├── __init__.py
│ ├── activation_offload.py
│ ├── checkpoint/
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ ├── fsdp_checkpoint_manager.py
│ │ └── megatron_checkpoint_manager.py
│ ├── config.py
│ ├── dataset/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── multiturn_sft_dataset.py
│ │ ├── rl_dataset.py
│ │ ├── rm_dataset.py
│ │ ├── sft_dataset.py
│ │ └── vision_utils.py
│ ├── debug/
│ │ ├── __init__.py
│ │ ├── performance.py
│ │ └── trajectory_tracker.py
│ ├── device.py
│ ├── distributed.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ └── torch_functional.py
│ ├── flops_counter.py
│ ├── fs.py
│ ├── fsdp_utils.py
│ ├── hdfs_io.py
│ ├── import_utils.py
│ ├── kernel/
│ │ ├── __init__.py
│ │ ├── kernels.py
│ │ └── linear_cross_entropy.py
│ ├── logger/
│ │ ├── __init__.py
│ │ └── aggregate_logger.py
│ ├── logging_utils.py
│ ├── megatron/
│ │ ├── __init__.py
│ │ ├── dist_checkpointing.py
│ │ ├── memory.py
│ │ ├── optimizer.py
│ │ ├── pipeline_parallel.py
│ │ ├── sequence_parallel.py
│ │ └── tensor_parallel.py
│ ├── megatron_utils.py
│ ├── memory_buffer.py
│ ├── metric/
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── model.py
│ ├── net_utils.py
│ ├── profiler/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── empty_annotations.py
│ │ ├── mstx_profile.py
│ │ ├── nvtx_profile.py
│ │ ├── performance.py
│ │ └── profile.py
│ ├── py_functional.py
│ ├── ray_utils.py
│ ├── rendezvous/
│ │ ├── __init__.py
│ │ └── ray_backend.py
│ ├── reward_score/
│ │ ├── __init__.py
│ │ ├── geo3k.py
│ │ ├── gsm8k.py
│ │ ├── math.py
│ │ ├── math_batch.py
│ │ ├── math_dapo.py
│ │ ├── math_verify.py
│ │ ├── prime_code/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── testing_util.py
│ │ │ └── utils.py
│ │ ├── prime_math/
│ │ │ ├── __init__.py
│ │ │ ├── grader.py
│ │ │ └── math_normalize.py
│ │ ├── sandbox_fusion/
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ └── search_r1_like_qa_em.py
│ ├── rollout_trace.py
│ ├── seqlen_balancing.py
│ ├── tokenizer.py
│ ├── torch_dtypes.py
│ ├── torch_functional.py
│ ├── tracking.py
│ ├── ulysses.py
│ └── vllm_utils.py
├── version/
│ └── version
└── workers/
├── __init__.py
├── actor/
│ ├── __init__.py
│ ├── base.py
│ ├── dp_actor.py
│ └── megatron_actor.py
├── critic/
│ ├── __init__.py
│ ├── base.py
│ ├── dp_critic.py
│ └── megatron_critic.py
├── engine/
│ ├── __init__.py
│ ├── base.py
│ ├── fsdp/
│ │ ├── __init__.py
│ │ ├── engine_impl.py
│ │ └── utils.py
│ └── megatron/
│ ├── __init__.py
│ └── engine_impl.py
├── fsdp_workers.py
├── megatron_workers.py
├── reward_manager/
│ ├── __init__.py
│ ├── batch.py
│ ├── dapo.py
│ ├── naive.py
│ ├── prime.py
│ └── registry.py
├── reward_model/
│ ├── __init__.py
│ ├── base.py
│ └── megatron/
│ ├── __init__.py
│ └── reward_model.py
├── roles/
│ ├── __init__.py
│ ├── actor.py
│ └── critic.py
├── rollout/
│ ├── __init__.py
│ ├── async_server.py
│ ├── base.py
│ ├── chat_scheduler.py
│ ├── hf_rollout.py
│ ├── naive/
│ │ ├── __init__.py
│ │ └── naive_rollout.py
│ ├── schemas.py
│ ├── sglang_rollout/
│ │ ├── __init__.py
│ │ ├── async_sglang_server.py
│ │ ├── sglang_rollout.py
│ │ └── utils.py
│ ├── tokenizer.py
│ └── vllm_rollout/
│ ├── __init__.py
│ ├── vllm_async_server.py
│ └── vllm_rollout_spmd.py
└── sharding_manager/
├── __init__.py
├── base.py
├── fsdp_sglang.py
├── fsdp_ulysses.py
├── fsdp_vllm.py
├── megatron_sglang.py
└── megatron_vllm.py
Showing preview only (694K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (7561 symbols across 912 files)
FILE: benchmarks/api/__init__.py
function load_config (line 23) | def load_config(config_path: str = None) -> Dict[str, Any]:
function get_client (line 49) | def get_client(model: str, **config) -> BaseLLMClient:
function get_client_from_config (line 80) | def get_client_from_config(
function batch_generate (line 114) | def batch_generate(
FILE: benchmarks/api/base.py
class BaseLLMClient (line 12) | class BaseLLMClient(ABC):
method __init__ (line 20) | def __init__(self, **config):
method _setup (line 33) | def _setup(self):
method _call_api (line 38) | def _call_api(
method _is_retryable_error (line 62) | def _is_retryable_error(self, error_msg: str) -> bool:
method _generate_with_retry (line 79) | def _generate_with_retry(
method generate (line 130) | def generate(
method batch_generate (line 155) | def batch_generate(
method __repr__ (line 238) | def __repr__(self) -> str:
FILE: benchmarks/api/claude.py
class ClaudeClient (line 10) | class ClaudeClient(BaseLLMClient):
method _setup (line 22) | def _setup(self):
method _call_api (line 39) | def _call_api(
FILE: benchmarks/api/deepseek.py
class DeepSeekClient (line 10) | class DeepSeekClient(BaseLLMClient):
method _setup (line 24) | def _setup(self):
method _call_api (line 44) | def _call_api(
FILE: benchmarks/api/example.py
function example1_use_config (line 9) | def example1_use_config():
function example2_direct_params (line 28) | def example2_direct_params():
function example3_batch_generate (line 61) | def example3_batch_generate():
function example4_custom_params (line 97) | def example4_custom_params():
function example5_error_handling (line 127) | def example5_error_handling():
function example6_switch_models (line 154) | def example6_switch_models():
function example7_user_portrait (line 177) | def example7_user_portrait():
function example8_direct_import (line 215) | def example8_direct_import():
function main (line 242) | def main():
FILE: benchmarks/api/gemini.py
class GeminiClient (line 12) | class GeminiClient(BaseLLMClient):
method _setup (line 26) | def _setup(self):
method _call_api (line 44) | def _call_api(
FILE: benchmarks/benchmark/base_generator.py
class Generator (line 13) | class Generator(ABC):
method __init__ (line 21) | def __init__(
method __str__ (line 33) | def __str__(self) -> str:
method generate (line 45) | def generate(
method get_hardware_info (line 145) | def get_hardware_info(self) -> Dict[str, Any]:
method _generate_two_stage_with_thinking (line 184) | def _generate_two_stage_with_thinking(
method _generate_two_stage_classification_with_thinking (line 352) | def _generate_two_stage_classification_with_thinking(
class HfTransformersMixin (line 483) | class HfTransformersMixin:
method _build_sampling_params (line 491) | def _build_sampling_params(self, **kwargs) -> tuple:
class VllmMixin (line 544) | class VllmMixin:
method _build_sampling_params (line 552) | def _build_sampling_params(self, **kwargs):
method _should_enable_optimizations (line 602) | def _should_enable_optimizations(self) -> bool:
class RayMixin (line 628) | class RayMixin:
method _initialize_ray_cluster (line 637) | def _initialize_ray_cluster(self):
method _determine_gpu_ids_from_cluster (line 688) | def _determine_gpu_ids_from_cluster(self) -> List[Dict[str, Any]]:
method _group_gpus_for_workers (line 751) | def _group_gpus_for_workers(
method _display_cluster_info (line 826) | def _display_cluster_info(self, gpu_list: List[Dict[str, Any]], num_wo...
method cleanup (line 881) | def cleanup(self):
FILE: benchmarks/benchmark/benchmark.py
class DataLoaderWrapper (line 20) | class DataLoaderWrapper:
method __init__ (line 22) | def __init__(self, model_path: str, benchmark_version: str, data_dir: ...
method _create_tokenizer (line 31) | def _create_tokenizer(self, model_path: str):
method load_data (line 44) | def load_data(self, task_name: str, split: str = "test", sample_size: ...
class Benchmark (line 59) | class Benchmark:
method __init__ (line 79) | def __init__(
method print_benchmark_table (line 100) | def print_benchmark_table():
method check_generator (line 148) | def check_generator(generator):
method run (line 157) | def run(
method _evaluate_single_task (line 249) | def _evaluate_single_task(
method _create_debug_file (line 329) | def _create_debug_file(generation_file: str, gen_data: Dict[str, Any],...
method _calculate_model_total_time (line 352) | def _calculate_model_total_time(model_results: Dict[str, Any]) -> float:
method _save_results_as_json (line 363) | def _save_results_as_json(eval_results: Dict[str, Any], output_path: s...
method _load_existing_results (line 372) | def _load_existing_results(output_path: str, task_types: List[str] = N...
method evaluate_dev (line 391) | def evaluate_dev(
FILE: benchmarks/benchmark/checkpoint_utils.py
function match_checkpoint_keys_to_model (line 15) | def match_checkpoint_keys_to_model(
function check_embedding_weight_sharing (line 69) | def check_embedding_weight_sharing(
function handle_weight_tying (line 121) | def handle_weight_tying(
function load_weights_from_pt (line 174) | def load_weights_from_pt(
function build_model_from_pt (line 264) | def build_model_from_pt(
function build_model_from_hf (line 316) | def build_model_from_hf(
function export_pt_to_safetensor (line 355) | def export_pt_to_safetensor(
FILE: benchmarks/benchmark/generation_runner.py
class GenerationRunner (line 23) | class GenerationRunner:
method __init__ (line 35) | def __init__(
method __call__ (line 49) | def __call__(
method save_generations (line 180) | def save_generations(
FILE: benchmarks/benchmark/gpu_utils.py
function _normalize_gpu_name (line 49) | def _normalize_gpu_name(gpu_name: str) -> str:
function get_gpu_tflops (line 101) | def get_gpu_tflops(gpu_name: str) -> Optional[float]:
function get_gpu_info (line 115) | def get_gpu_info() -> Dict[str, Any]:
FILE: benchmarks/benchmark/tasks/tasks.py
function get_available_benchmark_versions (line 16) | def get_available_benchmark_versions() -> List[str]:
function get_available_task_types (line 21) | def get_available_task_types(benchmark_version: str = LATEST_BENCHMARK_V...
function get_available_domains (line 27) | def get_available_domains(benchmark_version: str = LATEST_BENCHMARK_VERS...
function get_available_languages (line 36) | def get_available_languages(benchmark_version: str = LATEST_BENCHMARK_VE...
function check_benchmark_version (line 46) | def check_benchmark_version(benchmark_version: Optional[str]) -> str:
function check_task_types (line 72) | def check_task_types(
function check_splits (line 105) | def check_splits(
FILE: benchmarks/benchmark/tasks/v1_0/base_evaluator.py
class BaseEval (line 14) | class BaseEval(ABC):
method __init__ (line 17) | def __init__(
method evaluate (line 58) | def evaluate(self) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
method _all_samples_have_keys (line 90) | def _all_samples_have_keys(self, required_keys: List[str]) -> bool:
method required_metrics (line 99) | def required_metrics(self) -> Optional[List[str]]:
method _has_all_required_metrics (line 103) | def _has_all_required_metrics(self) -> bool:
method _compute_metrics_from_scratch (line 109) | def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[...
method _save_debug_json (line 114) | def _save_debug_json(
FILE: benchmarks/benchmark/tasks/v1_0/base_loader.py
class BaseLoader (line 16) | class BaseLoader(ABC):
method __init__ (line 19) | def __init__(
method load_data (line 41) | def load_data(self, split: str = "test", sample_size: Optional[Any] = ...
method _is_empty_value (line 111) | def _is_empty_value(value) -> bool:
method _convert_messages_format (line 134) | def _convert_messages_format(messages: list) -> list:
method _load_custom_chat_template (line 160) | def _load_custom_chat_template(self):
method _get_data_file_path (line 180) | def _get_data_file_path(self, split: str) -> str:
method _get_sample_data_file_path (line 199) | def _get_sample_data_file_path(self, split: str, sample_size: int) -> ...
method _load_dataframe (line 217) | def _load_dataframe(self, split: str) -> pd.DataFrame:
method _sample_data (line 233) | def _sample_data(self, df: pd.DataFrame, sample_size: int) -> pd.DataF...
method _save_sample_data (line 241) | def _save_sample_data(
method _load_sample_dataframe (line 257) | def _load_sample_dataframe(self, split: str, sample_size: int) -> Opti...
method _process_dataframe (line 269) | def _process_dataframe(self, df: pd.DataFrame) -> Dict[str, Dict[str, ...
method _make_metadata_serializable (line 348) | def _make_metadata_serializable(
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/evaluator.py
class ItemUnderstandEvaluator (line 14) | class ItemUnderstandEvaluator(BaseEval):
method required_metrics (line 18) | def required_metrics(self) -> List[str]:
method _compute_metrics_from_scratch (line 23) | def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[...
method _evaluate_wip (line 90) | def _evaluate_wip(
method _save_debug_info (line 168) | def _save_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/item_understand/utils.py
function extract_json_from_response (line 163) | def extract_json_from_response(response: str) -> Optional[Dict]:
function extract_wips_single (line 178) | def extract_wips_single(
function extract_wips_batch (line 209) | def extract_wips_batch(
function match_wips_single (line 259) | def match_wips_single(
function match_wips_batch (line 292) | def match_wips_batch(
function get_wip_score_int (line 359) | def get_wip_score_int(wip: Optional[Dict]) -> int:
function calculate_unweighted_metrics (line 366) | def calculate_unweighted_metrics(match_results: Dict[str, Dict], core_th...
function calculate_importance_weighted_metrics (line 432) | def calculate_importance_weighted_metrics(
function calculate_double_weighted_metrics (line 504) | def calculate_double_weighted_metrics(
function save_wip_detailed_results (line 608) | def save_wip_detailed_results(
function get_gt_cache_path (line 662) | def get_gt_cache_path(cache_dir: str, model_name: str) -> str:
function load_wip_results_cache (line 667) | def load_wip_results_cache(cache_path: str) -> Optional[Dict[str, Any]]:
function load_gt_wips_cache (line 706) | def load_gt_wips_cache(cache_path: str) -> Optional[Dict[str, List[Dict]]]:
function save_gt_wips_cache (line 736) | def save_gt_wips_cache(gt_wips: Dict[str, List[Dict]], cache_path: str):
function _load_or_extract_gt_wips (line 758) | def _load_or_extract_gt_wips(
function extract_after_think (line 805) | def extract_after_think(text: str) -> str:
function _load_or_extract_model_wips (line 811) | def _load_or_extract_model_wips(
function _load_or_match_wips (line 849) | def _load_or_match_wips(
function _compute_bertscore_incremental (line 892) | def _compute_bertscore_incremental(
function evaluate_wip (line 958) | def evaluate_wip(
FILE: benchmarks/benchmark/tasks/v1_0/label_pred/evaluator.py
class LabelPredEvaluator (line 20) | class LabelPredEvaluator(BaseEval):
method required_metrics (line 32) | def required_metrics(self) -> List[str]:
method _compute_metrics_from_scratch (line 36) | def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[...
method _save_debug_info (line 188) | def _save_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/label_pred/utils.py
function extract_label_from_answer (line 14) | def extract_label_from_answer(answer: str) -> int:
function extract_probability_from_logprobs (line 38) | def extract_probability_from_logprobs(
function calculate_auc (line 158) | def calculate_auc(
function get_debug_info (line 200) | def get_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/mfu_evaluator.py
function compute_mfu (line 18) | def compute_mfu(
function compute_mfu_from_generation_data (line 69) | def compute_mfu_from_generation_data(gen_data: Dict[str, Any]) -> Option...
FILE: benchmarks/benchmark/tasks/v1_0/rec_reason/evaluator.py
class RecoReasonEvaluator (line 15) | class RecoReasonEvaluator(BaseEval):
method required_metrics (line 19) | def required_metrics(self) -> List[str]:
method _compute_metrics_from_scratch (line 23) | def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[...
method _evaluate_reasoning (line 91) | def _evaluate_reasoning(
method _save_debug_info (line 157) | def _save_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/rec_reason/utils.py
function extract_refined_reasoning (line 65) | def extract_refined_reasoning(text: str) -> str:
function extract_after_think (line 96) | def extract_after_think(text: str) -> str:
function extract_json_from_response (line 103) | def extract_json_from_response(response: str) -> Optional[Dict]:
function evaluate_single (line 139) | def evaluate_single(
function evaluate_batch (line 174) | def evaluate_batch(
function calculate_metrics (line 229) | def calculate_metrics(eval_results: Dict[str, Dict]) -> Dict[str, Any]:
function get_per_sample_metrics (line 264) | def get_per_sample_metrics(eval_results: Dict[str, Dict]) -> Dict[str, D...
function get_cache_path (line 290) | def get_cache_path(save_dir: str, model_name: str) -> str:
function load_eval_cache (line 295) | def load_eval_cache(cache_path: str) -> Optional[Dict[str, Dict]]:
function save_eval_results (line 326) | def save_eval_results(
function evaluate_reasoning (line 363) | def evaluate_reasoning(
FILE: benchmarks/benchmark/tasks/v1_0/recommendation/evaluator.py
class RecommendationEvaluator (line 17) | class RecommendationEvaluator(BaseEval):
method required_metrics (line 33) | def required_metrics(self) -> List[str]:
method _select_generations_by_strategy (line 50) | def _select_generations_by_strategy(
method _evaluate_single_mode (line 105) | def _evaluate_single_mode(
method _calculate_metrics_from_counts (line 286) | def _calculate_metrics_from_counts(
method _compute_metrics_from_scratch (line 316) | def _compute_metrics_from_scratch(self) -> Tuple[Dict[str, Any], Dict[...
method _save_debug_info (line 429) | def _save_debug_info(self, debug_info: Dict[str, Any], metrics: Dict[s...
FILE: benchmarks/benchmark/tasks/v1_0/recommendation/utils.py
function extract_ids_from_answer (line 10) | def extract_ids_from_answer(answer: str) -> list[str]:
function extract_first_id_from_answer (line 29) | def extract_first_id_from_answer(answer: str) -> str:
function extract_id_from_generation (line 51) | def extract_id_from_generation(generation: str) -> str:
function compute_pass_at_k (line 94) | def compute_pass_at_k(
function compute_position1_pass_at_k (line 128) | def compute_position1_pass_at_k(
function compute_recall_at_k (line 163) | def compute_recall_at_k(
function get_unique_generations (line 210) | def get_unique_generations(
function get_debug_info (line 297) | def get_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py
function load_pid_mapping (line 21) | def load_pid_mapping(mapping_path: str) -> Dict[int, List[Dict[str, int]]]:
function encode_sid (line 47) | def encode_sid(c1: int, c2: int, c3: int) -> int:
function extract_sid_codes_from_text (line 60) | def extract_sid_codes_from_text(text: str) -> Optional[Tuple[int, int, i...
function _get_id_from_info (line 81) | def _get_id_from_info(info: Dict[str, int]) -> int:
function apply_sid_to_pid_strategy (line 94) | def apply_sid_to_pid_strategy(pid_info_list: List[Dict[str, int]], strat...
function extract_ids_from_answer (line 132) | def extract_ids_from_answer(answer: list[int]) -> list[int]:
function extract_first_id_from_answer (line 149) | def extract_first_id_from_answer(answer: List[int]) -> int:
function extract_id_from_generation (line 161) | def extract_id_from_generation(
function compute_pass_at_k (line 205) | def compute_pass_at_k(
function compute_position1_pass_at_k (line 239) | def compute_position1_pass_at_k(
function compute_recall_at_k (line 274) | def compute_recall_at_k(
function get_unique_generations (line 319) | def get_unique_generations(
function get_debug_info (line 399) | def get_debug_info(
FILE: benchmarks/benchmark/tasks/v1_0/registry.py
class TaskRegistration (line 38) | class TaskRegistration:
function get_loader (line 105) | def get_loader(task_name: str, data_dir: str, tokenizer: Optional[Any] =...
function get_evaluator (line 142) | def get_evaluator(task_name: str):
function get_task_config (line 167) | def get_task_config(task_name: str) -> Dict[str, Any]:
function get_all_tasks (line 190) | def get_all_tasks() -> list:
function get_tasks_by_category (line 200) | def get_tasks_by_category(category: str) -> list:
FILE: benchmarks/scripts/eval_dev_results.py
function get_args (line 6) | def get_args():
function main (line 33) | def main():
FILE: benchmarks/scripts/ray-vllm/evaluate.py
function main (line 17) | def main():
FILE: benchmarks/scripts/ray-vllm/utils/arguments.py
class ModelConfig (line 6) | class ModelConfig:
class InfrastructureConfig (line 34) | class InfrastructureConfig:
class InferenceConfig (line 66) | class InferenceConfig:
class GenerationConfig (line 85) | class GenerationConfig:
class PromptConfig (line 121) | class PromptConfig:
class BenchmarkConfig (line 131) | class BenchmarkConfig:
FILE: benchmarks/scripts/ray-vllm/utils/generator.py
class VllmWorker (line 14) | class VllmWorker:
method __init__ (line 24) | def __init__(
method get_model_parameters (line 85) | def get_model_parameters(self) -> Optional[float]:
method generate_batch (line 106) | def generate_batch(
method extract_token_logprobs_batch (line 270) | def extract_token_logprobs_batch(
class RayVllmGenerator (line 392) | class RayVllmGenerator(RayMixin, VllmMixin, Generator):
method __init__ (line 397) | def __init__(
method _count_model_parameters (line 611) | def _count_model_parameters(self) -> Optional[float]:
method _generate_standard (line 646) | def _generate_standard(
method extract_token_logprobs (line 766) | def extract_token_logprobs(
FILE: data/onerec_data/pretrain/item_understand.py
function pid_to_sid (line 32) | def pid_to_sid(pid, pid2sid: dict) -> str:
function build_segments (line 40) | def build_segments(sid: str, caption: str) -> str:
function process_row (line 48) | def process_row(row, pid2sid: dict) -> dict:
function main (line 73) | def main():
FILE: data/onerec_data/pretrain/user_profile.py
function process_row (line 17) | def process_row(row) -> dict:
function main (line 37) | def main():
FILE: data/onerec_data/pretrain/video_rec.py
function pids_to_sids (line 23) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_segments (line 36) | def build_segments(hist_sids: str, target_sids: str) -> str:
function process_row (line 43) | def process_row(row, pid2sid: dict) -> dict:
function main (line 70) | def main():
FILE: data/onerec_data/sft/ad_rec.py
function pids_to_sids (line 59) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_messages (line 72) | def build_messages(user_content: str, task_prompt: str, answer: str) -> ...
function process_row (line 84) | def process_row(row, pid2sid: dict) -> dict:
function main (line 139) | def main():
FILE: data/onerec_data/sft/interactive_rec.py
function pids_to_sids (line 52) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_messages (line 65) | def build_messages(user_profile: str, keyword: str, answer: str) -> str:
function process_row (line 78) | def process_row(row, pid2sid: dict) -> list:
function main (line 125) | def main():
FILE: data/onerec_data/sft/item_understand.py
function pid_to_sid (line 58) | def pid_to_sid(pid, pid2sid: dict) -> str:
function build_messages (line 66) | def build_messages(sid: str, caption: str) -> str:
function process_row (line 79) | def process_row(row, pid2sid: dict) -> dict:
function main (line 104) | def main():
FILE: data/onerec_data/sft/label_cond_rec.py
function pids_to_sids (line 54) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_messages (line 67) | def build_messages(user_content: str, task_prompt: str, answer: str) -> ...
function process_row (line 79) | def process_row(row, pid2sid: dict) -> dict:
function main (line 163) | def main():
FILE: data/onerec_data/sft/label_pred.py
function pids_to_sids (line 59) | def pids_to_sids(pids, pid2sid: dict) -> str:
function pid_to_sid (line 72) | def pid_to_sid(pid, pid2sid: dict) -> str:
function build_messages (line 80) | def build_messages(user_content: str, question: str, answer: str) -> str:
function process_row (line 92) | def process_row(row, pid2sid: dict) -> list:
function main (line 182) | def main():
FILE: data/onerec_data/sft/product_rec.py
function pids_to_sids (line 60) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_messages (line 73) | def build_messages(user_content: str, task_prompt: str, answer: str) -> ...
function process_row (line 85) | def process_row(row, video_pid2sid: dict, product_pid2sid: dict) -> dict:
function main (line 140) | def main():
FILE: data/onerec_data/sft/rec_reason.py
function build_messages (line 33) | def build_messages(user_prompt: str, answer: str) -> str:
function is_valid_str (line 42) | def is_valid_str(val) -> bool:
function process_row (line 53) | def process_row(row) -> dict:
function main (line 91) | def main():
FILE: data/onerec_data/sft/video_rec.py
function pids_to_sids (line 41) | def pids_to_sids(pids, pid2sid: dict) -> str:
function build_messages (line 54) | def build_messages(query: str, answer: str) -> str:
function process_row (line 67) | def process_row(row, pid2sid: dict) -> dict:
function main (line 94) | def main():
FILE: data/scripts/parquet_unicode_fix.py
function decode_unicode_json (line 27) | def decode_unicode_json(json_str: Optional[Union[str, bytes]]) -> Option...
function find_parquet_files (line 67) | def find_parquet_files(directory: str, recursive: bool = True) -> List[s...
function get_output_path (line 93) | def get_output_path(input_path: str, output_base: str, input_base: Optio...
function process_parquet_file (line 128) | def process_parquet_file(
function process_directory (line 185) | def process_directory(input_dir: str, output_dir: str, engine: str = 'py...
function main (line 241) | def main():
FILE: data/scripts/sample_data.py
function find_parquet_files (line 27) | def find_parquet_files(directory: str, recursive: bool = True) -> List[s...
function collect_parquet_files (line 50) | def collect_parquet_files(input_paths: List[str], recursive: bool = True...
function load_all_parquet_files (line 83) | def load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow...
function sample_dataframe (line 121) | def sample_dataframe(df: pd.DataFrame, num_samples: int, seed: int = Non...
function main (line 160) | def main():
FILE: data/scripts/split_data.py
function find_parquet_files (line 26) | def find_parquet_files(directory: str, recursive: bool = True) -> List[s...
function load_all_parquet_files (line 49) | def load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow...
function split_dataframe (line 87) | def split_dataframe(df: pd.DataFrame, max_rows: int, output_dir: str, pr...
function main (line 147) | def main():
FILE: data/scripts/train_test_split.py
function load_all_parquet_files (line 26) | def load_all_parquet_files(file_paths: List[str], engine: str = 'pyarrow...
function split_train_test (line 64) | def split_train_test(
function shuffle_dataframe (line 113) | def shuffle_dataframe(df: pd.DataFrame, seed: int = None) -> pd.DataFrame:
function main (line 133) | def main():
FILE: pretrain/onerec_llm/data/dataloaders.py
function get_chat_completion_parquet_dataloader (line 5) | def get_chat_completion_parquet_dataloader(sources: str,
function get_dataloader (line 43) | def get_dataloader(name: str, **kwargs):
FILE: pretrain/onerec_llm/data/local_shuffle_buffer.py
class LocalShuffleBuffer (line 19) | class LocalShuffleBuffer:
method __init__ (line 35) | def __init__(self, buffer_size: int = 2048, random_fetch: float = 0.01...
method _calc_sample_hash (line 57) | def _calc_sample_hash(self, obj: dict, buffer_epoch: int = None) -> int:
method add (line 85) | def add(self, obj: dict, fn: str = None, epoch: int = None) -> bool:
method get (line 136) | def get(self) -> dict:
method __len__ (line 154) | def __len__(self) -> int:
FILE: pretrain/onerec_llm/data/qwen3_dataset.py
function set_kwargs (line 36) | def set_kwargs(self, kwargs, **_kwargs):
class Qwen3ChatCompletionDataset (line 42) | class Qwen3ChatCompletionDataset(IterableDataset):
method __init__ (line 43) | def __init__(self, **kwargs):
method _build_source_dataset (line 92) | def _build_source_dataset(self, sources):
method _convert_messages (line 141) | def _convert_messages(self, messages):
method _get_assistant_mask (line 203) | def _get_assistant_mask(self, batch_input_ids: torch.Tensor,
method _get_rope_index_qwen3 (line 257) | def _get_rope_index_qwen3(
method _process_completion (line 265) | def _process_completion(self, sample: Dict[str, Any]) -> Dict[str, tor...
method _process_chat (line 321) | def _process_chat(self, sample: Dict[str, Any]) -> Dict[str, torch.Ten...
method _process (line 379) | def _process(self, sample, source_name=None):
method _cut_sample (line 398) | def _cut_sample(self, inputs, packable_length):
method _append_sample_packing (line 406) | def _append_sample_packing(self,
method _packing (line 436) | def _packing(self, buffer: List[Dict[str, torch.Tensor]]):
method __iter__ (line 490) | def __iter__(self):
class Qwen3NaiveParquetDataset (line 559) | class Qwen3NaiveParquetDataset(IterableDataset):
method __init__ (line 562) | def __init__(self, data_files, num_workers, **kwargs):
method _parser (line 574) | def _parser(self, raw_row_data, file_url):
method __iter__local_shuffle (line 618) | def __iter__local_shuffle(self):
method __iter__ (line 670) | def __iter__(self,):
method state_dict (line 675) | def state_dict(self):
method load_state_dict (line 684) | def load_state_dict(self, state_dict):
class Qwen3ChatCompletionParquetDataset (line 698) | class Qwen3ChatCompletionParquetDataset(Qwen3ChatCompletionDataset):
method __init__ (line 699) | def __init__(self, sources, num_workers, shuffle_seed=1024, num_epochs...
method _build_source_dataset (line 709) | def _build_source_dataset(self, sources):
method state_dict (line 739) | def state_dict(self):
method load_state_dict (line 744) | def load_state_dict(self, state_dict):
FILE: pretrain/onerec_llm/losses/ce.py
class CrossEntropyLoss (line 10) | class CrossEntropyLoss(nn.Module):
method __init__ (line 16) | def __init__(self,
method forward (line 27) | def forward(self, logits: torch.Tensor, labels: torch.Tensor):
class ChunkedLossComputer (line 72) | class ChunkedLossComputer:
method __init__ (line 82) | def __init__(self, lm_head: nn.Module, loss_fn: nn.Module, minibatch_s...
method forward_and_backward (line 102) | def forward_and_backward(self, input: torch.Tensor, labels: torch.Tens...
FILE: pretrain/onerec_llm/models/qwen3/configuration_qwen3.py
class Qwen3Config (line 25) | class Qwen3Config(PretrainedConfig):
method __init__ (line 152) | def __init__(
FILE: pretrain/onerec_llm/models/qwen3/modeling_qwen3.py
class Qwen3RMSNorm (line 55) | class Qwen3RMSNorm(nn.Module):
method __init__ (line 56) | def __init__(self, hidden_size, eps=1e-6):
method forward (line 64) | def forward(self, hidden_states):
method extra_repr (line 71) | def extra_repr(self):
class Qwen3MLP (line 75) | class Qwen3MLP(nn.Module):
method __init__ (line 76) | def __init__(self, config):
method forward (line 86) | def forward(self, x):
function rotate_half (line 91) | def rotate_half(x):
function apply_rotary_pos_emb (line 98) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di...
function repeat_kv (line 125) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
function eager_attention_forward (line 137) | def eager_attention_forward(
class Qwen3Attention (line 163) | class Qwen3Attention(nn.Module):
method __init__ (line 166) | def __init__(self, config: Qwen3Config, layer_idx: int):
method forward (line 198) | def forward(
class Qwen3DecoderLayer (line 253) | class Qwen3DecoderLayer(nn.Module):
method __init__ (line 254) | def __init__(self, config: Qwen3Config, layer_idx: int):
method forward (line 269) | def forward(
class Qwen3RotaryEmbedding (line 312) | class Qwen3RotaryEmbedding(nn.Module):
method __init__ (line 313) | def __init__(self, config: Qwen3Config, device=None):
method forward (line 332) | def forward(self, x, position_ids):
class Qwen3PreTrainedModel (line 367) | class Qwen3PreTrainedModel(PreTrainedModel):
method _init_weights (line 381) | def _init_weights(self, module):
class Qwen3Model (line 464) | class Qwen3Model(Qwen3PreTrainedModel):
method __init__ (line 472) | def __init__(self, config: Qwen3Config):
method get_input_embeddings (line 488) | def get_input_embeddings(self):
method set_input_embeddings (line 491) | def set_input_embeddings(self, value):
method forward (line 496) | def forward(
method _update_causal_mask (line 613) | def _update_causal_mask(
method _prepare_4d_causal_attention_mask_with_cache_position (line 697) | def _prepare_4d_causal_attention_mask_with_cache_position(
class KwargsForCausalLM (line 766) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Qwen3ForCausalLM (line 769) | class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
method __init__ (line 776) | def __init__(self, config):
method get_input_embeddings (line 785) | def get_input_embeddings(self):
method set_input_embeddings (line 788) | def set_input_embeddings(self, value):
method get_output_embeddings (line 791) | def get_output_embeddings(self):
method set_output_embeddings (line 794) | def set_output_embeddings(self, new_embeddings):
method set_decoder (line 797) | def set_decoder(self, decoder):
method get_decoder (line 800) | def get_decoder(self):
method forward (line 807) | def forward(
class Qwen3ForSequenceClassification (line 909) | class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
method __init__ (line 910) | def __init__(self, config):
method get_input_embeddings (line 919) | def get_input_embeddings(self):
method set_input_embeddings (line 922) | def set_input_embeddings(self, value):
method forward (line 927) | def forward(
class Qwen3ForTokenClassification (line 1002) | class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
method __init__ (line 1003) | def __init__(self, config):
method get_input_embeddings (line 1019) | def get_input_embeddings(self):
method set_input_embeddings (line 1022) | def set_input_embeddings(self, value):
method forward (line 1032) | def forward(
class Qwen3ForQuestionAnswering (line 1084) | class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
method __init__ (line 1087) | def __init__(self, config):
method get_input_embeddings (line 1095) | def get_input_embeddings(self):
method set_input_embeddings (line 1098) | def set_input_embeddings(self, value):
method forward (line 1103) | def forward(
FILE: pretrain/onerec_llm/models/qwen3/modular_qwen3.py
class Qwen3RMSNorm (line 51) | class Qwen3RMSNorm(LlamaRMSNorm):
class Qwen3MLP (line 55) | class Qwen3MLP(GemmaMLP):
class Qwen3Attention (line 59) | class Qwen3Attention(LlamaAttention):
method __init__ (line 60) | def __init__(self, config: Qwen3Config, layer_idx: int):
method forward (line 72) | def forward(
class Qwen3DecoderLayer (line 123) | class Qwen3DecoderLayer(LlamaDecoderLayer):
method __init__ (line 124) | def __init__(self, config: Qwen3Config, layer_idx: int):
class Qwen3Model (line 137) | class Qwen3Model(MistralModel): # mistral model creates sliding window
class KwargsForCausalLM (line 141) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Qwen3ForCausalLM (line 144) | class Qwen3ForCausalLM(LlamaForCausalLM):
method forward (line 145) | def forward(
class Qwen3ForSequenceClassification (line 183) | class Qwen3ForSequenceClassification(LlamaForSequenceClassification):
class Qwen3ForTokenClassification (line 187) | class Qwen3ForTokenClassification(LlamaForTokenClassification):
class Qwen3ForQuestionAnswering (line 191) | class Qwen3ForQuestionAnswering(LlamaForQuestionAnswering):
FILE: pretrain/onerec_llm/training/activations.py
function set_activation_checkpointing (line 8) | def set_activation_checkpointing(
FILE: pretrain/onerec_llm/training/checkpoint.py
function load_safetensors (line 32) | def load_safetensors(path: Union[Path, str]) -> Dict[str, torch.Tensor]:
function safe_torch_load (line 48) | def safe_torch_load(
function load_hf_checkpoint (line 83) | def load_hf_checkpoint(
function load_checkpoint_to_state_dict (line 141) | def load_checkpoint_to_state_dict(checkpoint_path: Union[str, os.PathLik...
class CheckpointerInterface (line 232) | class CheckpointerInterface(Protocol):
method load_checkpoint (line 235) | def load_checkpoint(self, **kwargs) -> Dict[str, Any]:
method save_checkpoint (line 239) | def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None:
class DistributedCheckpointer (line 243) | class DistributedCheckpointer(CheckpointerInterface):
method __init__ (line 253) | def __init__(
method get_latest_checkpoint (line 262) | def get_latest_checkpoint(self, checkpoint_dir: str) -> Optional[str]:
method load_checkpoint (line 291) | def load_checkpoint(
method save_checkpoint (line 337) | def save_checkpoint(
class AppState (line 432) | class AppState(Stateful):
method __init__ (line 442) | def __init__(self, model, optimizer=None, call_back=None):
method set_call_back (line 446) | def set_call_back(self, cb):
method state_dict (line 450) | def state_dict(self):
method load_state_dict (line 461) | def load_state_dict(self, state_dict):
FILE: pretrain/onerec_llm/training/common.py
function set_default_dtype (line 10) | def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
FILE: pretrain/onerec_llm/training/distributed.py
function shard_model (line 15) | def shard_model(
function load_from_full_model_state_dict (line 65) | def load_from_full_model_state_dict(
FILE: pretrain/onerec_llm/training/gradients.py
function clip_grad_by_value (line 15) | def clip_grad_by_value(
function clip_grad_norm (line 29) | def clip_grad_norm(
function compute_fsdp_zero2_grad_norm (line 43) | def compute_fsdp_zero2_grad_norm(
class EmbeddingGradientMasker (line 76) | class EmbeddingGradientMasker:
method __init__ (line 91) | def __init__(self, model, config, start_optimize_embedding_index):
method _find_embedding_parameters (line 102) | def _find_embedding_parameters(self):
method _save_initial_weights (line 108) | def _save_initial_weights(self):
method save_frozen_params (line 143) | def save_frozen_params(self):
method apply_gradient_mask (line 147) | def apply_gradient_mask(self, optimizer=None):
method restore_frozen_params (line 151) | def restore_frozen_params(self):
FILE: pretrain/onerec_llm/training/lr_schedulers.py
function _get_cosine_schedule_with_warmup_lr_lambda (line 11) | def _get_cosine_schedule_with_warmup_lr_lambda(
function get_cosine_scheduler (line 47) | def get_cosine_scheduler(
function get_scheduler (line 91) | def get_scheduler(
FILE: pretrain/onerec_llm/utils/common.py
function print_rank_n (line 20) | def print_rank_n(*msg, rank=0):
function print_rank_0 (line 28) | def print_rank_0(*msg):
function get_optimizer_grouped_parameters (line 31) | def get_optimizer_grouped_parameters(model,
function to_device (line 71) | def to_device(batch, device, non_blocking=True):
function to_cuda (line 77) | def to_cuda(batch, non_blocking=True):
function set_random_seed (line 81) | def set_random_seed(seed):
function dist_reduce_dict (line 90) | def dist_reduce_dict(local_dict, group=None):
class Timer (line 117) | class Timer:
method __init__ (line 118) | def __init__(self, desc: str = ""):
method __enter__ (line 121) | def __enter__(self):
method __exit__ (line 126) | def __exit__(self, exc_type, exc_value, traceback):
FILE: pretrain/onerec_llm/utils/data_utils.py
function calculate_text_hash (line 17) | def calculate_text_hash(text):
function shell_hdfs_ls (line 31) | def shell_hdfs_ls(source_dir):
class FakeParquetFileFromFastParquetFile (line 55) | class FakeParquetFileFromFastParquetFile:
method __init__ (line 58) | def __init__(self, fast_parquet_file):
method read_row_group (line 68) | def read_row_group(self, i):
function load_parquet_file (line 73) | def load_parquet_file(
function _load_parquet_from_hdfs (line 125) | def _load_parquet_from_hdfs(
function _load_parquet_from_path (line 193) | def _load_parquet_from_path(file_path: str, parquet_backend: str) -> pq....
function _clean_cache_if_needed (line 201) | def _clean_cache_if_needed(cache_dir: str, max_cache_files: int):
function _download_from_hdfs (line 227) | def _download_from_hdfs(hdfs_path: str, local_path: str, hadoop_cmd: str):
FILE: pretrain/onerec_llm/utils/distributed.py
function get_world_size_and_rank (line 15) | def get_world_size_and_rank() -> Tuple[int, int]:
function get_rank (line 34) | def get_rank() -> int:
function get_world_size (line 44) | def get_world_size() -> int:
function is_distributed (line 54) | def is_distributed() -> bool:
FILE: pretrain/onerec_llm/utils/ds_utils.py
function convert_dataclass_to_dict (line 12) | def convert_dataclass_to_dict(obj: Any) -> Any:
function tensor_statistics (line 19) | def tensor_statistics(tensor: torch.Tensor, n: int = -1, **kwargs) -> Tu...
function print_input_info (line 105) | def print_input_info(
function format_dict_or_list (line 230) | def format_dict_or_list(obj: Any, indent_level: int = 0, indent_size: in...
FILE: pretrain/onerec_llm/utils/mfu_stats.py
function _sum_if_list (line 20) | def _sum_if_list(x: Union[int, List[int]]) -> int:
function _get_gpu_model (line 26) | def _get_gpu_model() -> str:
function _is_h800 (line 99) | def _is_h800() -> bool:
function _get_gpu_flops (line 106) | def _get_gpu_flops() -> float:
function _calculate_decoder_layer_flops (line 115) | def _calculate_decoder_layer_flops(
function _calculate_decoder_layers_flops (line 208) | def _calculate_decoder_layers_flops(
function _calculate_llm_flops (line 269) | def _calculate_llm_flops(llm_params: easydict.EasyDict) -> Dict:
function _extract_model_params (line 306) | def _extract_model_params(config_path: str) -> easydict.EasyDict:
function _calc_mfu (line 342) | def _calc_mfu(
class MFUStats (line 394) | class MFUStats:
method __init__ (line 403) | def __init__(self, args):
method set (line 409) | def set(self, num_tokens: int, num_samples: int) -> None:
method mfu (line 419) | def mfu(self, secs: float, global_step: int) -> Dict[str, float]:
FILE: pretrain/onerec_llm/utils/time_tracker.py
class TimeTracker (line 8) | class TimeTracker:
method __init__ (line 28) | def __init__(
method tick (line 44) | def tick(self, name: str) -> None:
method stat (line 82) | def stat(self) -> Dict[str, float]:
FILE: pretrain/onerec_llm/utils/worker_utils.py
function get_worker_info (line 8) | def get_worker_info():
function pytorch_worker_info (line 36) | def pytorch_worker_info(group=None):
FILE: pretrain/recipes/train_qwen3.py
class TrainingMetrics (line 83) | class TrainingMetrics:
method __init__ (line 91) | def __init__(self):
method reset_period_accumulators (line 100) | def reset_period_accumulators(self):
method update (line 115) | def update(self, num_tokens, num_samples, num_valid_tokens):
class TensorBoardLogger (line 128) | class TensorBoardLogger:
method __init__ (line 131) | def __init__(self, tb_writer: Optional[SummaryWriter]):
method _write_async (line 144) | def _write_async(self, tb_writer, metrics_queue):
method log (line 202) | def log(self, global_step, log_dict, ticker_stats, ds_loss, ds_tokens,...
function get_argument_parser (line 210) | def get_argument_parser() -> argparse.ArgumentParser:
class StateDictConverter (line 297) | class StateDictConverter:
method convert (line 300) | def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, to...
method revert (line 304) | def revert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor...
function _init_profiler (line 309) | def _init_profiler(output_dir: str, enable: bool = False) -> Optional[to...
function save_model_checkpoint (line 343) | def save_model_checkpoint(
function initialize_distributed (line 421) | def initialize_distributed() -> Tuple[int, int, int]:
function initialize_model (line 441) | def initialize_model(
function load_model_checkpoint (line 552) | def load_model_checkpoint(
function load_optimizer_checkpoint (line 579) | def load_optimizer_checkpoint(
function load_dataloader_checkpoint (line 603) | def load_dataloader_checkpoint(args) -> Optional[Dict]:
function load_checkpoint (line 629) | def load_checkpoint(
function compute_forward_backward (line 683) | def compute_forward_backward(
function compute_metrics (line 760) | def compute_metrics(
function log_training_step (line 861) | def log_training_step(
function train (line 1014) | def train():
FILE: pretrain/tests/test_qwen3_dataset_file_distribution.py
class TestFileDistribution (line 16) | class TestFileDistribution(unittest.TestCase):
method setUp (line 19) | def setUp(self):
method _get_file_distribution (line 27) | def _get_file_distribution(self, rank, world_size, worker, num_workers):
method test_file_distribution_no_overlap (line 48) | def test_file_distribution_no_overlap(self):
method test_file_distribution_completeness (line 77) | def test_file_distribution_completeness(self):
method test_file_distribution_different_configs (line 96) | def test_file_distribution_different_configs(self):
method test_file_distribution_balance (line 135) | def test_file_distribution_balance(self):
method test_file_distribution_with_epochs (line 163) | def test_file_distribution_with_epochs(self):
class TestFileDistributionLogic (line 197) | class TestFileDistributionLogic(unittest.TestCase):
method setUp (line 200) | def setUp(self):
method test_distribution_algorithm (line 206) | def test_distribution_algorithm(self):
function run_distribution_test_manual (line 238) | def run_distribution_test_manual():
FILE: pretrain/tools/model_converter/convert_checkpoint_to_hf.py
function _get_torch_dtype (line 53) | def _get_torch_dtype(dtype_str: str) -> torch.dtype:
function _extract_state_dict_from_checkpoint (line 75) | def _extract_state_dict_from_checkpoint(checkpoint: Dict, model_only: bo...
function _convert_state_dict_to_shards (line 104) | def _convert_state_dict_to_shards(
function pth_to_hf_format (line 199) | def pth_to_hf_format(
function dcp_to_hf_format (line 249) | def dcp_to_hf_format(
function copy_hf_config_files (line 309) | def copy_hf_config_files(
function get_argument_parser (line 364) | def get_argument_parser() -> argparse.ArgumentParser:
function main (line 427) | def main() -> None:
FILE: pretrain/tools/model_converter/expand_qwen3_vocab.py
function _align_vocab_size (line 28) | def _align_vocab_size(vocab_size: int, alignment: int = 256) -> int:
function _fix_chat_template (line 41) | def _fix_chat_template(reco_model_dir: str, hf_model_dir: str) -> None:
function _test_expanded_vocab (line 76) | def _test_expanded_vocab(model, tokenizer, new_tokens: List[str]) -> None:
function expand_qwen3_vocab_for_pretraining (line 111) | def expand_qwen3_vocab_for_pretraining(
function generate_itemic_tokens (line 199) | def generate_itemic_tokens(itemic_layer_n: int, vocab_size_per_layer: in...
function load_tokens_from_file (line 251) | def load_tokens_from_file(tokens_file: str) -> List[str]:
function main (line 280) | def main():
FILE: pretrain/tools/model_test/test_hf_model.py
function load_model (line 27) | def load_model(
function print_model_info (line 61) | def print_model_info(model) -> None:
function generate_text (line 81) | def generate_text(
function generate_chat (line 135) | def generate_chat(
function load_test_cases_from_file (line 210) | def load_test_cases_from_file(file_path: Union[str, Path]) -> tuple:
function get_default_test_cases (line 246) | def get_default_test_cases() -> tuple:
function main (line 279) | def main():
FILE: tokenizer/infer_res_kmeans.py
function load_embeddings (line 8) | def load_embeddings(emb_path):
function main (line 16) | def main():
FILE: tokenizer/res_kmeans.py
class ResKmeans (line 4) | class ResKmeans(nn.Module):
method __init__ (line 6) | def __init__(self, n_layers, codebook_size, dim, extra_kmeans_config=N...
method calc_loss (line 17) | def calc_loss(self, x, out, epsilon=1e-4):
method train_kmeans (line 22) | def train_kmeans(self, inputs, verbose=True):
method encode (line 40) | def encode(self, x, n_layers=None):
method decode (line 56) | def decode(self, code):
FILE: tokenizer/train_res_kmeans.py
function read_train_data (line 11) | def read_train_data(path, emb_dim):
function main (line 40) | def main():
FILE: verl_distillation/docs/_static/js/resizable-sidebar.js
function setupNavigationFix (line 136) | function setupNavigationFix() {
FILE: verl_distillation/examples/data_preprocess/aime2024_multiturn_w_tool.py
function make_map_fn (line 49) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/dapo_multiturn_w_tool.py
function make_map_fn (line 49) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/full_hh_rlhf.py
function generate_sft_dataset (line 30) | def generate_sft_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh...
function generate_rm_dataset (line 61) | def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_...
function generate_rl_dataset (line 93) | def generate_rl_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_...
FILE: verl_distillation/examples/data_preprocess/geo3k.py
function make_map_fn (line 58) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/geo3k_multiturn_w_tool.py
function make_map_fn (line 60) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/gsm8k.py
function extract_solution (line 27) | def extract_solution(solution_str):
function make_map_fn (line 60) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/gsm8k_multiturn_sft.py
function extract_solution (line 27) | def extract_solution(solution_str):
function make_map_fn (line 60) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/gsm8k_multiturn_w_interaction.py
function extract_solution (line 29) | def extract_solution(solution_str):
function make_map_fn (line 62) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/gsm8k_multiturn_w_tool.py
function extract_solution (line 29) | def extract_solution(solution_str):
function make_map_fn (line 62) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/gsm8k_tool_agent_loop.py
function extract_solution (line 29) | def extract_solution(solution_str):
function make_map_fn (line 62) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/hellaswag.py
function preprocess (line 28) | def preprocess(text):
function make_map_fn (line 62) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/math_dataset.py
function extract_solution (line 28) | def extract_solution(solution_str):
function make_map_fn (line 63) | def make_map_fn(split):
FILE: verl_distillation/examples/data_preprocess/multiturn.py
function main (line 24) | def main():
FILE: verl_distillation/examples/data_preprocess/preprocess_search_r1_dataset.py
function process_single_row (line 45) | def process_single_row(row, current_split_name, row_index):
function main (line 101) | def main():
FILE: verl_distillation/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
function load_corpus (line 34) | def load_corpus(corpus_path: str):
function load_docs (line 39) | def load_docs(corpus, doc_idxs):
function load_model (line 44) | def load_model(model_path: str, use_fp16: bool = False):
function pooling (line 54) | def pooling(pooler_output, last_hidden_state, attention_mask=None, pooli...
class Encoder (line 66) | class Encoder:
method __init__ (line 67) | def __init__(self, model_name, model_path, pooling_method, max_length,...
method encode (line 78) | def encode(self, query_list: list[str], is_query=True) -> np.ndarray:
class BaseRetriever (line 124) | class BaseRetriever:
method __init__ (line 125) | def __init__(self, config):
method _search (line 133) | def _search(self, query: str, num: int, return_score: bool):
method _batch_search (line 136) | def _batch_search(self, query_list: list[str], num: int, return_score:...
method search (line 139) | def search(self, query: str, num: int = None, return_score: bool = Fal...
method batch_search (line 142) | def batch_search(self, query_list: list[str], num: int = None, return_...
class BM25Retriever (line 146) | class BM25Retriever(BaseRetriever):
method __init__ (line 147) | def __init__(self, config):
method _check_contain_doc (line 157) | def _check_contain_doc(self):
method _search (line 160) | def _search(self, query: str, num: int = None, return_score: bool = Fa...
method _batch_search (line 193) | def _batch_search(self, query_list: list[str], num: int = None, return...
class DenseRetriever (line 206) | class DenseRetriever(BaseRetriever):
method __init__ (line 207) | def __init__(self, config):
method _search (line 227) | def _search(self, query: str, num: int = None, return_score: bool = Fa...
method _batch_search (line 240) | def _batch_search(self, query_list: list[str], num: int = None, return...
function get_retriever (line 273) | def get_retriever(config):
class Config (line 285) | class Config:
method __init__ (line 291) | def __init__(
class QueryRequest (line 320) | class QueryRequest(BaseModel):
function retrieve_endpoint (line 330) | def retrieve_endpoint(request: QueryRequest):
FILE: verl_distillation/examples/split_placement/main_ppo_split.py
function _select_rm_score_fn (line 29) | def _select_rm_score_fn(data_source):
class RewardManager (line 38) | class RewardManager:
method __init__ (line 39) | def __init__(self, tokenizer, num_examine) -> None:
method __call__ (line 43) | def __call__(self, data: DataProto, return_dict: bool = False):
function main (line 95) | def main(config):
function main_task (line 110) | def main_task(config):
FILE: verl_distillation/examples/split_placement/split_monkey_patch.py
function fit (line 38) | def fit(self):
FILE: verl_distillation/examples/tutorial/agent_loop_get_started/sandbox.py
class SandboxTool (line 22) | class SandboxTool(BaseTool):
method __init__ (line 23) | def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
method code_interpreter (line 28) | async def code_interpreter(self, code: str) -> str:
method get_openai_tool_schema (line 47) | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
method execute (line 51) | async def execute(self, instance_id: str, parameters: dict, **kwargs) ...
FILE: verl_distillation/recipe/char_count/create_dataset.py
function generate_random_char (line 30) | def generate_random_char():
function create_prompt_response (line 34) | def create_prompt_response(min_length=3, max_length=5):
FILE: verl_distillation/recipe/char_count/reward_function.py
function char_count_reward_function (line 22) | def char_count_reward_function(data_source, solution_str, ground_truth, ...
FILE: verl_distillation/recipe/collabllm/collabllm_agent_loop.py
class CollabLLMAgentLoop (line 32) | class CollabLLMAgentLoop(ToolAgentLoop):
method run (line 34) | async def run(self, sampling_params: dict[str, Any], **kwargs) -> Agen...
method run_agent_data_loop (line 117) | async def run_agent_data_loop(self, agent_data: AgentData, sampling_pa...
FILE: verl_distillation/recipe/collabllm/collabllm_interation.py
class CollabLLMInteraction (line 82) | class CollabLLMInteraction(BaseInteraction):
method __init__ (line 91) | def __init__(self, config: dict):
method start_interaction (line 107) | async def start_interaction(
method generate_response (line 122) | async def generate_response(
method finalize_interaction (line 190) | async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
method _parse_messages (line 193) | def _parse_messages(self, messages, strip_sys_prompt=True):
function extract_json (line 207) | def extract_json(s):
FILE: verl_distillation/recipe/collabllm/metrics/accuracy.py
function compute_score (line 53) | async def compute_score(data_source, messages, ground_truth, extra_info,...
FILE: verl_distillation/recipe/collabllm/metrics/bleu_score.py
function compute_score (line 67) | async def compute_score(data_source, messages, ground_truth, extra_info,...
FILE: verl_distillation/recipe/collabllm/metrics/interactivity.py
function compute_score (line 61) | async def compute_score(data_source, messages, ground_truth, extra_info,...
FILE: verl_distillation/recipe/collabllm/metrics/pass_rate.py
function compute_score (line 73) | async def compute_score(data_source, messages, ground_truth, extra_info,...
FILE: verl_distillation/recipe/collabllm/metrics/token_amount.py
function compute_score (line 17) | def compute_score(data_source, messages, ground_truth, extra_info, **kwa...
FILE: verl_distillation/recipe/collabllm/process_dataset.py
function collapse_example (line 83) | def collapse_example(example: dict[str, Any]) -> dict[str, Any]:
function save_parquet (line 125) | def save_parquet(ds_split: Dataset, filename: str, out_dir: str) -> None:
function maybe_copy_to_hdfs (line 132) | def maybe_copy_to_hdfs(local_dir: str, hdfs_dir: Optional[str]) -> None:
function main (line 146) | def main():
FILE: verl_distillation/recipe/collabllm/reward_function.py
function conversation_level_reward_func (line 34) | async def conversation_level_reward_func(
class CollabLLMRewardManager (line 108) | class CollabLLMRewardManager(AbstractRewardManager):
method __init__ (line 113) | def __init__(
method __call__ (line 134) | def __call__(self, data: DataProto, return_dict: bool = False) -> torc...
method _compute_rewards_async (line 149) | async def _compute_rewards_async(self, data: DataProto, return_dict: b...
FILE: verl_distillation/recipe/collabllm/utils.py
function parse_messages (line 23) | def parse_messages(messages, strip_sys_prompt=True):
function strip_system_prompt (line 42) | def strip_system_prompt(messages):
function extract_json (line 53) | def extract_json(s):
function remove_think_block (line 222) | def remove_think_block(msg: dict):
function is_valid_messages (line 231) | def is_valid_messages(msg: dict) -> bool:
FILE: verl_distillation/recipe/dapo/dapo_ray_trainer.py
class RayDAPOTrainer (line 45) | class RayDAPOTrainer(RayPPOTrainer):
method compute_kl_related_metrics (line 50) | def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, ...
method fit (line 76) | def fit(self):
FILE: verl_distillation/recipe/dapo/main_dapo.py
function main (line 32) | def main(config):
function run_ppo (line 36) | def run_ppo(config) -> None:
class TaskRunner (line 69) | class TaskRunner:
method run (line 70) | def run(self, config):
FILE: verl_distillation/recipe/deepeyes/deepeyes.py
class CustomRLHFDataset (line 52) | class CustomRLHFDataset(RLHFDataset):
method __getitem__ (line 53) | def __getitem__(self, item):
function compute_score (line 182) | def compute_score(data_source: str, solution_str: str, ground_truth: str...
FILE: verl_distillation/recipe/entropy/entropy_ray_trainer.py
class RayEntropyTrainer (line 42) | class RayEntropyTrainer(RayPPOTrainer):
method compute_kl_related_metrics (line 47) | def compute_kl_related_metrics(self, batch: DataProto, timing_raw: dict):
method fit (line 66) | def fit(self):
FILE: verl_distillation/recipe/entropy/main_entropy.py
function main (line 27) | def main(config):
function run_ppo (line 31) | def run_ppo(config) -> None:
function merge_dict (line 52) | def merge_dict(a: dict, b: dict) -> dict:
class TaskRunner (line 68) | class TaskRunner:
method run (line 69) | def run(self, config):
function create_rl_dataset (line 193) | def create_rl_dataset(data_paths, data_config, tokenizer, processor, max...
function create_rl_sampler (line 232) | def create_rl_sampler(data_config, dataset):
FILE: verl_distillation/recipe/entropy/reward.py
function load_reward_manager (line 26) | def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
function compute_reward_async (line 80) | def compute_reward_async(data: DataProto, config, tokenizer):
FILE: verl_distillation/recipe/entropy/reward_score/__init__.py
function _default_compute_score (line 21) | def _default_compute_score(
FILE: verl_distillation/recipe/entropy/reward_score/entropy_math/__init__.py
function timeout_ours (line 40) | def timeout_ours(timeout_seconds: int = 8):
function mathd_normalize_answer (line 67) | def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:
function _strip_string (line 220) | def _strip_string(string):
function normalize_final_answer (line 440) | def normalize_final_answer(final_answer: str) -> str:
function repeatness (line 477) | def repeatness(s: str):
class timeout (line 520) | class timeout:
method __init__ (line 521) | def __init__(self, seconds=1, error_message="Timeout"):
method handle_timeout (line 525) | def handle_timeout(self, signum, frame):
method __enter__ (line 528) | def __enter__(self):
method __exit__ (line 532) | def __exit__(self, type, value, traceback):
function latex_eval (line 536) | def latex_eval(latex):
function numeric_equal (line 542) | def numeric_equal(prediction: float, reference: float):
function symbolic_equal (line 553) | def symbolic_equal(a, b):
function _is_latex_equal (line 609) | def _is_latex_equal(str1, str2):
function is_latex_equal (line 629) | def is_latex_equal(given_answer: str, ground_truth: str) -> bool:
function is_value_equal (line 682) | def is_value_equal(given_answer: str, ground_truth: str) -> bool:
function _sympy_parse (line 701) | def _sympy_parse(expr: str):
function _parse_latex (line 710) | def _parse_latex(expr: str) -> str:
function _is_float (line 728) | def _is_float(num: str) -> bool:
function _is_int (line 736) | def _is_int(x: float) -> bool:
function _is_frac (line 743) | def _is_frac(expr: str) -> bool:
function _str_is_int (line 747) | def _str_is_int(x: str) -> bool:
function _str_to_int (line 756) | def _str_to_int(x: str) -> bool:
function _inject_implicit_mixed_number (line 762) | def _inject_implicit_mixed_number(step: str):
function _strip_properly_formatted_commas (line 772) | def _strip_properly_formatted_commas(expr: str):
function _normalize (line 783) | def _normalize(expr: str) -> str:
function count_unknown_letters_in_expr (line 856) | def count_unknown_letters_in_expr(expr: str):
function should_allow_eval (line 863) | def should_allow_eval(expr: str):
function are_equal_under_sympy (line 880) | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized...
function split_tuple (line 894) | def split_tuple(expr: str):
function last_boxed_only_string (line 913) | def last_boxed_only_string(string):
function remove_boxed (line 940) | def remove_boxed(s):
function extract_boxed_answer (line 950) | def extract_boxed_answer(solution: str) -> str:
function grade_answer_sympy (line 957) | def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
function grade_answer_mathd (line 997) | def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:
function extract_answer (line 1007) | def extract_answer(passage: str) -> str:
function grade (line 1013) | def grade(model_answer: str, gt_answer: str, fast: bool = True):
function compute_score (line 1027) | def compute_score(model_response, gt_answer, fast=False):
FILE: verl_distillation/recipe/entropy/reward_score/entropy_math/grader.py
function is_digit (line 109) | def is_digit(s):
function normalize (line 121) | def normalize(answer, pi) -> str:
function handle_base (line 141) | def handle_base(x) -> str:
function handle_pi (line 150) | def handle_pi(string, pi):
function math_equal (line 174) | def math_equal(
function symbolic_equal (line 324) | def symbolic_equal(a, b, tolerance, timeout=10.0):
function format_intervals (line 362) | def format_intervals(prediction):
FILE: verl_distillation/recipe/entropy/reward_score/entropy_math/math_normalize.py
function normalize_answer (line 44) | def normalize_answer(answer: Optional[str]) -> Optional[str]:
function _fix_fracs (line 58) | def _fix_fracs(string):
function _fix_a_slash_b (line 90) | def _fix_a_slash_b(string):
function _remove_right_units (line 105) | def _remove_right_units(string):
function _fix_sqrt (line 115) | def _fix_sqrt(string):
function _strip_string (line 130) | def _strip_string(string):
FILE: verl_distillation/recipe/fapo/prepare_fapo_data.py
function example_map_fn (line 27) | def example_map_fn(example, idx, process_fn, data_source, ability, split):
function build_aime2024_dataset (line 39) | def build_aime2024_dataset():
function build_aime2025_dataset (line 53) | def build_aime2025_dataset():
function build_gpqa_diamond_dataset (line 67) | def build_gpqa_diamond_dataset():
function build_dapo_train_dataset (line 107) | def build_dapo_train_dataset():
FILE: verl_distillation/recipe/fapo/reward_fn_genrm.py
function parse_ans (line 20) | def parse_ans(
function compute_score_fapo_genrm (line 35) | def compute_score_fapo_genrm(
FILE: verl_distillation/recipe/fapo/reward_fn_reasoning.py
function verify (line 29) | def verify(
function compute_score_baseline (line 45) | async def compute_score_baseline(
function generate_aiohttp (line 77) | async def generate_aiohttp(router_address: str, prompt_ids: list[int], s...
function compute_score_fapo (line 97) | async def compute_score_fapo(
FILE: verl_distillation/recipe/fapo/reward_fn_reasoning_remote.py
function verify (line 22) | def verify(
function compute_score_baseline (line 37) | def compute_score_baseline(
function chat_completions_aiohttp (line 75) | async def chat_completions_aiohttp(address, **chat_complete_request):
function judge_fp_process (line 95) | def judge_fp_process(response, return_err_step=False):
function compute_score_fapo (line 109) | async def compute_score_fapo(data_source, solution_str, ground_truth, ex...
FILE: verl_distillation/recipe/fully_async_policy/agent_loop/agent_loop.py
class FullyAsyncLLMServerManager (line 43) | class FullyAsyncLLMServerManager(AsyncLLMServerManager):
method generate_for_partial (line 44) | async def generate_for_partial(self, request_id, prompt_ids, sampling_...
class FullyAsyncAgentLoopOutput (line 56) | class FullyAsyncAgentLoopOutput(AgentLoopOutput):
class FullyAsyncAgentLoopWorker (line 70) | class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase):
method __init__ (line 71) | def __init__(
method generate_sequences_no_post (line 77) | async def generate_sequences_no_post(
method _partial_run_agent_loop (line 127) | async def _partial_run_agent_loop(
class FullyAsyncAgentLoopManager (line 157) | class FullyAsyncAgentLoopManager(AgentLoopManager):
method __init__ (line 158) | def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = ...
method create (line 173) | async def create(cls, config: DictConfig, worker_group: RayWorkerGroup...
method _async_init (line 178) | async def _async_init(self):
method _initialize_llm_servers_async (line 188) | async def _initialize_llm_servers_async(self):
method generate_single_sample_async (line 217) | async def generate_single_sample_async(
method _select_best_worker (line 236) | def _select_best_worker(self):
method cancel (line 245) | async def cancel(self):
method resume (line 248) | async def resume(self):
method wake_up (line 251) | async def wake_up(self):
method sleep (line 254) | async def sleep(self):
method reset_prefix_cache (line 257) | async def reset_prefix_cache(self):
FILE: verl_distillation/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py
class PartialSingleTurnAgentLoop (line 29) | class PartialSingleTurnAgentLoop(AgentLoopBase):
method __init__ (line 32) | def __init__(self, *args, **kwargs):
method run (line 38) | async def run(self, sampling_params: dict[str, Any], **kwargs) -> Agen...
FILE: verl_distillation/recipe/fully_async_policy/detach_utils.py
function postprocess_agent_loop_outputs (line 29) | def postprocess_agent_loop_outputs(rs: "RolloutSample", tokenizer, confi...
class RolloutSample (line 159) | class RolloutSample:
class ValidateMetrics (line 181) | class ValidateMetrics:
function prepare_single_generation_data (line 190) | def prepare_single_generation_data(batch_dict, global_steps, rollout_n) ...
function process_rollout_log_probs (line 217) | def process_rollout_log_probs(data_proto: DataProto, rollout_log_probs: ...
function merge_rollout_sample (line 249) | def merge_rollout_sample(config, tokenizer, rs: RolloutSample, processor):
function assemble_batch_from_rollout_samples (line 280) | def assemble_batch_from_rollout_samples(
class MetricsAggregator (line 366) | class MetricsAggregator:
method __init__ (line 369) | def __init__(self, total_gpus: int):
method _init_aggregation_rules (line 384) | def _init_aggregation_rules(self) -> dict[str, dict[str, list[str]]]:
method add_step_metrics (line 399) | def add_step_metrics(self, metrics: dict[str, Any], sample_count: int,...
method _get_aggregation_type (line 415) | def _get_aggregation_type(self, metric_name: str) -> str:
method _aggregate_single_metric (line 437) | def _aggregate_single_metric(self, metric_name: str, values: list[floa...
method get_aggregated_metrics (line 476) | def get_aggregated_metrics(self) -> dict[str, Any]:
method _special_metrics_aggergate (line 495) | def _special_metrics_aggergate(self, aggregated: dict[str, Any]) -> di...
method reset (line 515) | def reset(self):
method get_current_stats (line 522) | def get_current_stats(self) -> dict[str, Any]:
FILE: verl_distillation/recipe/fully_async_policy/fsdp2_utils.py
function fsdp2_sharded_save_to_cpu (line 28) | def fsdp2_sharded_save_to_cpu(
function fsdp2_sharded_load_from_cpu (line 70) | def fsdp2_sharded_load_from_cpu(
FILE: verl_distillation/recipe/fully_async_policy/fsdp_workers.py
function get_inference_model (line 43) | def get_inference_model(rollout):
class DetachNcclSync (line 64) | class DetachNcclSync(AsyncActorRolloutRefWorker):
method _get_actor_params (line 65) | def _get_actor_params(self):
method sync_rollout_weights (line 69) | def sync_rollout_weights(self):
class DetachActorWorker (line 97) | class DetachActorWorker(DetachNcclSync):
method _get_actor_params (line 98) | def _get_actor_params(self):
method get_actor_weights_info (line 109) | def get_actor_weights_info(self):
method save_model_to_cpu (line 129) | def save_model_to_cpu(self, n):
method restore_model_from_cpu (line 135) | def restore_model_from_cpu(self, n):
method clear_cpu_model (line 141) | def clear_cpu_model(self, n):
class DetachAsyncRolloutWorker (line 146) | class DetachAsyncRolloutWorker(DetachNcclSync):
method __init__ (line 147) | def __init__(self, config: DictConfig, role: str):
method set_actor_weights_info (line 152) | def set_actor_weights_info(self, weights_info):
FILE: verl_distillation/recipe/fully_async_policy/fully_async_main.py
function create_resource_pool_manager (line 33) | def create_resource_pool_manager(config, roles: list) -> ResourcePoolMan...
function create_role_worker_mapping (line 72) | def create_role_worker_mapping(config):
class FullyAsyncTaskRunner (line 126) | class FullyAsyncTaskRunner:
method __init__ (line 131) | def __init__(self):
method run (line 136) | def run(self, config):
method _initialize_components (line 141) | def _initialize_components(self, config) -> None:
method _create_rollouter (line 219) | def _create_rollouter(self, config) -> None:
method _create_trainer (line 238) | def _create_trainer(self, config) -> None:
method _run_training_loop (line 261) | def _run_training_loop(self):
function main (line 298) | def main(config):
FILE: verl_distillation/recipe/fully_async_policy/fully_async_rollouter.py
class FullyAsyncRollouter (line 37) | class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
method __init__ (line 44) | def __init__(
method set_message_queue_client (line 153) | async def set_message_queue_client(self, message_queue_client: Message...
method set_max_required_samples (line 158) | async def set_max_required_samples(self):
method get_rollout_wg (line 183) | def get_rollout_wg(self):
method get_max_queue_size (line 187) | def get_max_queue_size(self):
method get_total_train_steps (line 190) | def get_total_train_steps(self):
method update_param_version (line 193) | async def update_param_version(self, version: int, validate: bool = Fa...
method _validate_config (line 237) | def _validate_config(self):
method init_workers (line 243) | async def init_workers(self):
method _create_actor_rollout_classes (line 256) | def _create_actor_rollout_classes(self):
method _init_models (line 267) | def _init_models(self):
method _create_continuous_iterator (line 272) | def _create_continuous_iterator(self):
method _init_async_rollout_manager (line 281) | async def _init_async_rollout_manager(self):
method _feed_samples (line 293) | async def _feed_samples(self):
method _processor_worker (line 333) | async def _processor_worker(self):
method _process_single_sample_streaming (line 409) | async def _process_single_sample_streaming(self, rollout_sample: Rollo...
method _consumer_worker (line 436) | async def _consumer_worker(self):
method _streaming_generation_main (line 457) | async def _streaming_generation_main(self):
method fit (line 507) | async def fit(self):
method _async_monitor_loop (line 543) | async def _async_monitor_loop(self):
method _should_pause_generation (line 572) | async def _should_pause_generation(self) -> bool:
method pause (line 596) | async def pause(self):
method resume (line 611) | async def resume(self, dependency_ref: ObjectRef = None):
method get_statistics (line 623) | async def get_statistics(self) -> dict:
FILE: verl_distillation/recipe/fully_async_policy/fully_async_trainer.py
class FullyAsyncTrainer (line 39) | class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
method __init__ (line 45) | def __init__(
method set_message_queue_client (line 110) | def set_message_queue_client(self, message_queue_client: MessageQueueC...
method set_parameter_synchronizer (line 114) | def set_parameter_synchronizer(self, param_synchronizer):
method set_total_train_steps (line 118) | def set_total_train_steps(self, total_train_steps):
method get_actor_wg (line 122) | def get_actor_wg(self):
method _get_samples_from_queue (line 126) | def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]:
method _create_actor_rollout_classes (line 185) | def _create_actor_rollout_classes(self):
method _init_models (line 196) | def _init_models(self):
method _init_async_rollout_manager (line 213) | def _init_async_rollout_manager(self):
method fit (line 216) | def fit(self):
method load_checkpoint (line 309) | def load_checkpoint(self):
method _collect_metrics_from_samples (line 312) | def _collect_metrics_from_samples(self, batch, metrics):
method _trigger_parameter_sync_after_step (line 334) | def _trigger_parameter_sync_after_step(self, validate: bool = False, g...
FILE: verl_distillation/recipe/fully_async_policy/megatron_worker.py
function get_inference_model (line 40) | def get_inference_model(rollout):
class DetachNcclSync (line 61) | class DetachNcclSync(AsyncActorRolloutRefWorker):
method _get_actor_params (line 62) | def _get_actor_params(self):
method sync_rollout_weights (line 66) | def sync_rollout_weights(self):
class DetachActorWorker (line 93) | class DetachActorWorker(DetachNcclSync):
method _get_actor_params_generator (line 94) | def _get_actor_params_generator(self):
method get_actor_weights_info (line 110) | def get_actor_weights_info(self):
class DetachAsyncRolloutWorker (line 124) | class DetachAsyncRolloutWorker(DetachNcclSync):
method __init__ (line 125) | def __init__(self, config: DictConfig, role: str):
method set_actor_weights_info (line 130) | def set_actor_weights_info(self, weights_info):
FILE: verl_distillation/recipe/fully_async_policy/message_queue.py
class MessageQueue (line 27) | class MessageQueue:
method __init__ (line 32) | def __init__(self, config: DictConfig, max_queue_size: int = 1000):
method put_sample (line 67) | async def put_sample(self, sample: Any, param_version: int) -> bool:
method get_sample (line 98) | async def get_sample(self) -> Any | None:
method update_param_version (line 118) | async def update_param_version(self, version: int):
method get_queue_size (line 125) | async def get_queue_size(self) -> int:
method get_statistics (line 130) | async def get_statistics(self) -> dict[str, Any]:
method clear_queue (line 143) | async def clear_queue(self):
method shutdown (line 150) | async def shutdown(self):
method get_memory_usage (line 158) | async def get_memory_usage(self) -> dict:
method put_validate (line 190) | async def put_validate(self, data):
method get_validate (line 194) | async def get_validate(self):
class MessageQueueClient (line 202) | class MessageQueueClient:
method __init__ (line 205) | def __init__(self, queue_actor: Any):
method put_sample (line 208) | async def put_sample(self, sample: Any, param_version: int) -> bool:
method put_validate (line 213) | async def put_validate(self, data: Any) -> bool:
method get_validate_sync (line 217) | def get_validate_sync(self) -> Any | None:
method get_sample (line 220) | async def get_sample(self) -> Any | None:
method get_queue_size (line 225) | async def get_queue_size(self) -> int:
method get_statistics (line 230) | async def get_statistics(self) -> dict[str, Any]:
method clear_queue (line 235) | async def clear_queue(self):
method shutdown (line 240) | async def shutdown(self):
method get_memory_usage (line 245) | async def get_memory_usage(self) -> dict:
method put_sample_sync (line 251) | def put_sample_sync(self, sample: Any, param_version: int) -> bool:
method get_sample_sync (line 255) | def get_sample_sync(self) -> Any | None:
method get_statistics_sync (line 259) | def get_statistics_sync(self) -> dict[str, Any]:
method update_param_version_sync (line 263) | def update_param_version_sync(self, version: int):
FILE: verl_distillation/recipe/fully_async_policy/param_sync.py
class ParameterSynchronizer (line 25) | class ParameterSynchronizer:
method __init__ (line 32) | def __init__(self, config, trainer, rollouter, mq):
method get_current_param_version (line 53) | def get_current_param_version(self) -> int:
method get_weights_info (line 57) | def get_weights_info(self):
method _init_weights_info (line 61) | def _init_weights_info(self):
method _init_sync_group (line 65) | def _init_sync_group(self):
method sync_weights (line 76) | def sync_weights(self, version, validate=False, global_steps=0):
method wait_last_valid (line 98) | def wait_last_valid(self):
FILE: verl_distillation/recipe/fully_async_policy/ray_trainer.py
class FullyAsyncRayPPOTrainer (line 53) | class FullyAsyncRayPPOTrainer(RayPPOTrainer):
method init_workers (line 54) | def init_workers(self):
method _init_resource_pools (line 67) | def _init_resource_pools(self):
method _create_worker_classes (line 72) | def _create_worker_classes(self):
method _create_actor_rollout_classes (line 78) | def _create_actor_rollout_classes(self):
method _create_critic_class (line 81) | def _create_critic_class(self):
method _create_reference_policy_class (line 89) | def _create_reference_policy_class(self):
method _create_reward_model_class (line 101) | def _create_reward_model_class(self):
method _init_worker_groups (line 109) | def _init_worker_groups(self):
method _init_models (line 143) | def _init_models(self):
method _init_async_rollout_manager (line 160) | def _init_async_rollout_manager(self):
method fit (line 163) | def fit(self):
method _prepare_generate_batch (line 306) | def _prepare_generate_batch(self, batch_dict):
method _post_generate_batch (line 319) | def _post_generate_batch(self, batch, gen_batch_output, metrics):
method _process_batch_common (line 339) | def _process_batch_common(self, batch, metrics, timing_raw, local_trig...
method _log_rollout (line 465) | def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw):
method _validate_metrics (line 490) | def _validate_metrics(self, is_last_step, last_val_metrics, metrics, t...
method _check_save_checkpoint (line 503) | def _check_save_checkpoint(self, is_last_step, timing_raw):
method _collect_metrics (line 524) | def _collect_metrics(self, batch, epoch, metrics, timing_raw):
method _post_batch_processing (line 542) | def _post_batch_processing(self, batch: DataProto):
FILE: verl_distillation/recipe/fully_async_policy/unittest/simple_streaming_demo.py
class SimpleStreamingSystem (line 20) | class SimpleStreamingSystem:
method __init__ (line 23) | def __init__(self, max_concurrent_tasks: int = 4):
method data_stream (line 30) | async def data_stream(self):
method add_data_stream (line 47) | async def add_data_stream(self, data_list: list[dict]):
method _process_data_async (line 61) | async def _process_data_async(self, data_item: dict):
method _submit_worker (line 85) | async def _submit_worker(self):
method _consumer_worker (line 120) | async def _consumer_worker(self):
method run_demo (line 140) | async def run_demo(self):
function main (line 169) | async def main():
FILE: verl_distillation/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py
class vLLMHttpServerForPartial (line 37) | class vLLMHttpServerForPartial(vLLMHttpServerBase):
method __init__ (line 38) | def __init__(
method _generate_step (line 57) | async def _generate_step(
method generate_for_partial (line 79) | async def generate_for_partial(
method cancel (line 120) | async def cancel(self):
method resume (line 126) | async def resume(self):
method reset_prefix_cache (line 130) | async def reset_prefix_cache(self):
class FullyAsyncvLLMReplica (line 135) | class FullyAsyncvLLMReplica(vLLMReplica):
method __init__ (line 136) | def __init__(
method cancel (line 147) | async def cancel(self):
method resume (line 151) | async def resume(self):
method reset_prefix_cache (line 155) | async def reset_prefix_cache(self):
FILE: verl_distillation/recipe/genrm_remote/reward_function.py
function get_response (line 45) | def get_response(problem, solution_str, ground_truth):
function compute_reward (line 68) | def compute_reward(response):
function compute_score (line 80) | def compute_score(data_source, solution_str, ground_truth, extra_info):
function compute_score_batch (line 99) | def compute_score_batch(data_sources, solution_strs, ground_truths, extr...
FILE: verl_distillation/recipe/infigui-g1/reward_fn.py
function extract_think_format (line 30) | def extract_think_format(predict_str: str) -> None | dict[str, str]:
function extract_and_parse_json (line 73) | def extract_and_parse_json(input_string, wrapper):
function _extract_verifiable_answer (line 122) | def _extract_verifiable_answer(answer):
function _format_reward (line 152) | def _format_reward(answer):
function _check_collinear (line 175) | def _check_collinear(points_2d):
function _accuracy_reward (line 205) | def _accuracy_reward(answer, ground_truth):
function calculate_point_reward (line 260) | def calculate_point_reward(solution_str, ground_truth, extra_info=None, ...
function aer_gui_reward_function (line 340) | def aer_gui_reward_function(data_source, solution_str, ground_truth, ext...
FILE: verl_distillation/recipe/langgraph_agent/chat_model.py
class MaxTokenExceededError (line 47) | class MaxTokenExceededError(Exception):
class ChatModel (line 53) | class ChatModel(BaseChatModel):
method bind_tools (line 81) | def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, ...
method with_structured_output (line 98) | def with_structured_output(
method _generate (line 108) | def _generate(
method _agenerate (line 116) | async def _agenerate(
method _llm_type (line 151) | def _llm_type(self) -> str:
method _preprocess (line 155) | async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any...
method _postprocess (line 253) | async def _postprocess(
class TruncateStructuredTool (line 320) | class TruncateStructuredTool(StructuredTool):
method _arun (line 329) | async def _arun(
function convert_to_agent_output (line 350) | def convert_to_agent_output(messages: list[BaseMessage], response_length...
FILE: verl_distillation/recipe/langgraph_agent/example/create_dataset.py
function generate_math_expression (line 25) | def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max...
function test (line 80) | def test():
function calculate (line 98) | def calculate(expression: str) -> float:
function generate_data (line 213) | def generate_data(total_num_dataset, split):
FILE: verl_distillation/recipe/langgraph_agent/example/math_expression.py
function calculate (line 20) | def calculate(a: int, b: int, operand: str) -> int:
class MathExpressionReactAgentLoop (line 35) | class MathExpressionReactAgentLoop(ReactAgentLoop):
method init_class (line 37) | def init_class(cls, config, tokenizer, **kwargs):
FILE: verl_distillation/recipe/langgraph_agent/react_agent_loop.py
function call_model (line 36) | async def call_model(state: MessagesState, config: RunnableConfig):
function should_continue (line 47) | def should_continue(state: MessagesState, config: RunnableConfig) -> Lit...
class ReactAgentLoop (line 71) | class ReactAgentLoop(AgentLoopBase):
method init_class (line 73) | def init_class(cls, config, tokenizer, **kwargs):
method build_graph (line 83) | def build_graph(cls) -> StateGraph:
method run (line 102) | async def run(self, sampling_params: dict[str, Any], **kwargs) -> Agen...
FILE: verl_distillation/recipe/langgraph_agent/test_react_agent_loop.py
function init_config (line 30) | def init_config() -> DictConfig:
function get_current_temperature (line 53) | def get_current_temperature(location: str, unit: str = "celsius"):
function get_temperature_date (line 72) | def get_temperature_date(location: str, date: str, unit: str = "celsius"):
class TestReactAgentLoop (line 92) | class TestReactAgentLoop(ReactAgentLoop):
method init_class (line 94) | def init_class(cls, config, tokenizer, **kwargs):
function test_react_agent (line 100) | def test_react_agent(init_config):
FILE: verl_distillation/recipe/minicpmo/rl_dataset.py
function build_transform (line 39) | def build_transform():
function build_image_bound (line 50) | def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):
function preprocess (line 70) | def preprocess(
function slice_image (line 204) | def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_siz...
function ensure_divide (line 255) | def ensure_divide(length, patch_size):
function find_best_resize (line 259) | def find_best_resize(original_size, scale_resolution, patch_size, allow_...
function get_refine_size (line 270) | def get_refine_size(original_size, grid, scale_resolution, patch_size, a...
function split_to_patches (line 292) | def split_to_patches(image, grid):
function get_grid_placeholder (line 309) | def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
function reshape_by_patch (line 330) | def reshape_by_patch(image_tensor, patch_size):
function init_minicpmo_config (line 343) | def init_minicpmo_config(processor, config):
function process_minicpmo_data (line 358) | def process_minicpmo_data(
class RLHFDataset (line 402) | class RLHFDataset(Dataset):
method __init__ (line 419) | def __init__(
method _download (line 457) | def _download(self, use_origin_parquet=False):
method _read_files_and_tokenize (line 464) | def _read_files_and_tokenize(self):
method resume_dataset_state (line 474) | def resume_dataset_state(self):
method __len__ (line 483) | def __len__(self):
method _build_messages (line 486) | def _build_messages(self, example: dict):
method __getitem__ (line 489) | def __getitem__(self, item):
method __getstate__ (line 563) | def __getstate__(self):
FILE: verl_distillation/recipe/one_step_off_policy/distributed_util.py
function stateless_init_process_group (line 18) | def stateless_init_process_group(master_address, master_port, rank, worl...
FILE: verl_distillation/recipe/one_step_off_policy/fsdp_workers.py
class ActorRolloutRefWorker (line 59) | class ActorRolloutRefWorker(ARRWorker):
method create_weight_sync_group (line 61) | def create_weight_sync_group(self, master_address, master_port, rank_o...
method _get_actor_params (line 71) | def _get_actor_params(self):
method sync_rollout_weights (line 82) | def sync_rollout_weights(self):
method update_weights (line 118) | async def update_weights(self, inference_engine, params):
method get_actor_weights_info (line 132) | def get_actor_weights_info(self):
class RolloutWorker (line 152) | class RolloutWorker(ActorRolloutRefWorker):
method __init__ (line 153) | def __init__(self, config: DictConfig, role: str):
method init_model (line 189) | def init_model(self):
method async_generate_sequences (line 277) | def async_generate_sequences(self, prompts):
method set_actor_weights_info (line 323) | def set_actor_weights_info(self, weights_info):
class AsyncActorRolloutRefWorker (line 328) | class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
method __init__ (line 329) | def __init__(self, *args, **kwargs):
FILE: verl_distillation/recipe/one_step_off_policy/main_ppo.py
function main (line 37) | def main(config):
function run_ppo (line 42) | def run_ppo(config) -> None:
class TaskRunner (line 78) | class TaskRunner:
method run (line 79) | def run(self, config):
FILE: verl_distillation/recipe/one_step_off_policy/megatron_workers.py
class ActorRolloutRefWorker (line 42) | class ActorRolloutRefWorker(ARRWorker):
method __init__ (line 43) | def __init__(self, config: DictConfig, role: str):
method create_weight_sync_group (line 52) | def create_weight_sync_group(self, master_address, master_port, rank_o...
method _get_actor_params_generator (line 62) | def _get_actor_params_generator(self):
method sync_rollout_weights (line 82) | def sync_rollout_weights(self):
method get_actor_weights_info (line 110) | def get_actor_weights_info(self):
class RolloutWorker (line 124) | class RolloutWorker(ActorRolloutRefWorker):
method __init__ (line 125) | def __init__(self, config: DictConfig, role: str):
method init_model (line 130) | def init_model(self):
method async_generate_sequences (line 198) | def async_generate_sequences(self, *args, **kwargs):
method set_actor_weights_info (line 202) | def set_actor_weights_info(self, weights_info):
class AsyncActorRolloutRefWorker (line 207) | class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
method __init__ (line 208) | def __init__(self, *args, **kwargs):
FILE: verl_distillation/recipe/one_step_off_policy/ray_trainer.py
class GenerationBatchFuture (line 58) | class GenerationBatchFuture:
method __init__ (line 63) | def __init__(self, epoch, batch, gen_batch_output, future_reward=None):
method get (line 75) | def get(self):
class OneStepOffRayTrainer (line 95) | class OneStepOffRayTrainer(RayPPOTrainer):
method __init__ (line 98) | def __init__(
method _validate (line 164) | def _validate(self):
method init_workers (line 170) | def init_workers(self):
method create_weight_sync_group (line 278) | def create_weight_sync_group(self):
method sync_rollout_weights (line 297) | def sync_rollout_weights(self):
method _create_continuous_iterator (line 302) | def _create_continuous_iterator(self):
method _async_gen_next_batch (line 311) | def _async_gen_next_batch(self, continuous_iterator):
method _launch_individual_rewards (line 363) | def _launch_individual_rewards(gen_batch_output, config, tokenizer, or...
method fit (line 400) | def fit(self):
FILE: verl_distillation/recipe/one_step_off_policy/sglang_sharding_manager.py
class SGLangShardingManager (line 32) | class SGLangShardingManager(BaseShardingManager):
method __init__ (line 34) | def __init__(self, device_mesh: DeviceMesh):
method __enter__ (line 44) | def __enter__(self):
method __exit__ (line 48) | def __exit__(self, exc_type, exc_value, traceback):
method preprocess_data (line 53) | def preprocess_data(self, data: DataProto) -> DataProto:
method postprocess_data (line 65) | def postprocess_data(self, data: DataProto) -> DataProto:
FILE: verl_distillation/recipe/one_step_off_policy/utils.py
function need_critic (line 22) | def need_critic(config: DictConfig) -> bool:
FILE: verl_distillation/recipe/one_step_off_policy/vllm_sharding_manager.py
class VLLMShardingManager (line 33) | class VLLMShardingManager(BaseShardingManager):
method __init__ (line 35) | def __init__(self, inference_engine, device_mesh: DeviceMesh):
method __enter__ (line 49) | def __enter__(self):
method __exit__ (line 53) | def __exit__(self, exc_type, exc_value, traceback):
method preprocess_data (line 58) | def preprocess_data(self, data: DataProto) -> DataProto:
method postprocess_data (line 69) | def postprocess_data(self, data: DataProto) -> DataProto:
FILE: verl_distillation/recipe/onpolicy_distill/main_onpolicy_distill.py
function create_rl_dataset (line 28) | def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_...
class OnPolicyDistillTaskRunner (line 78) | class OnPolicyDistillTaskRunner(TaskRunner):
method run (line 80) | def run(self, config):
function main (line 193) | def main(config):
FILE: verl_distillation/recipe/onpolicy_distill/onpolicy_distill_trainer.py
class RayOnPolicyDistillTrainer (line 44) | class RayOnPolicyDistillTrainer(RayPPOTrainer):
method compute_kl_related_metrics (line 49) | def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, ...
method fit (line 75) | def fit(self):
FILE: verl_distillation/recipe/open_math_reasoning/compute_score.py
function compute_score_data_source (line 16) | def compute_score_data_source(data_source, response, ground_truth):
FILE: verl_distillation/recipe/open_math_reasoning/prepare_eval_dataset.py
function make_map_fn (line 29) | def make_map_fn(data_source):
FILE: verl_distillation/recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py
function make_map_fn (line 46) | def make_map_fn(split):
FILE: verl_distillation/recipe/prime/main_prime.py
function main (line 43) | def main(config):
function run_prime (line 47) | def run_prime(config, compute_score=None):
function main_task (line 62) | def main_task(config, compute_score=None):
FILE: verl_distillation/recipe/prime/prime_core_algos.py
function compute_rloo_advantage_return (line 21) | def compute_rloo_advantage_return(data: verl.DataProto, response_mask: t...
function compute_ce_dpo_loss_rm (line 82) | def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
function compute_detach_dpo_loss_rm (line 88) | def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, re...
function compute_dpo_accuracy (line 119) | def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_sampl...
function compute_dpo_abs_accuracy (line 146) | def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_s...
FILE: verl_distillation/recipe/prime/prime_dp_rm.py
class DataParallelPRIMERewardModel (line 38) | class DataParallelPRIMERewardModel:
method __init__ (line 39) | def __init__(self, config, reward_module: nn.Module, ref_module: nn.Mo...
method _forward_micro_batch (line 51) | def _forward_micro_batch(self, micro_batch, prompt_length):
method _optimizer_step (line 230) | def _optimizer_step(self):
method prime_norm (line 242) | def prime_norm(self, token_level_scores):
method compute_rm_score (line 248) | def compute_rm_score(self, data: DataProto):
method update_rm (line 291) | def update_rm(self, data: DataProto):
FILE: verl_distillation/recipe/prime/prime_fsdp_workers.py
class PRIMERewardModelWorker (line 53) | class PRIMERewardModelWorker(Worker):
method __init__ (line 54) | def __init__(self, config):
method _build_reward_ref_model_optimizer (line 89) | def _build_reward_ref_model_optimizer(self, config):
method init_model (line 241) | def init_model(self):
method compute_rm_score (line 273) | def compute_rm_score(self, data: DataProto):
method update_rm (line 308) | def update_rm(self, data: DataProto):
method save_checkpoint (line 350) | def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, m...
method load_checkpoint (line 365) | def load_checkpoint(self, local_path, del_local_after_load=True):
FILE: verl_distillation/recipe/prime/prime_ray_trainer.py
function compute_advantage (line 43) | def compute_advantage(data: DataProto, adv_estimator, config):
function compute_data_metrics (line 59) | def compute_data_metrics(batch, use_critic=True):
function compute_response_mask (line 120) | def compute_response_mask(data: DataProto):
function compute_timing_metrics (line 127) | def compute_timing_metrics(batch, timing_raw):
class RayPRIMETrainer (line 147) | class RayPRIMETrainer(RayPPOTrainer):
method __init__ (line 154) | def __init__(
method _create_dataloader (line 180) | def _create_dataloader(self, *args, **kwargs):
method _save_checkpoint (line 236) | def _save_checkpoint(self):
method _load_checkpoint (line 281) | def _load_checkpoint(self):
method compute_reward (line 334) | def compute_reward(self, batch: DataProto, n_samples: int):
method fit (line 373) | def fit(self):
method filter_and_downsample (line 557) | def filter_and_downsample(self, scores, batch: DataProto):
FILE: verl_distillation/recipe/r1/data_process.py
function example_map_fn (line 27) | def example_map_fn(example, idx, process_fn, data_source, ability, split):
function build_aime2024_dataset (line 39) | def build_aime2024_dataset():
function build_gpqa_dimond_dataset (line 53) | def build_gpqa_dimond_dataset():
function build_cnmo2024_dataset (line 84) | def build_cnmo2024_dataset():
function build_livecodebench_dataset (line 107) | def build_livecodebench_dataset():
FILE: verl_distillation/recipe/r1/main_eval.py
function process_item (line 34) | def process_item(config, data_source, response_lst, reward_data):
function main (line 42) | def main(config):
FILE: verl_distillation/recipe/r1/reward_score.py
function reward_func (line 16) | def reward_func(data_source, solution_str, ground_truth, extra_info=None):
FILE: verl_distillation/recipe/r1/tasks/gpqa.py
function compute_score (line 21) | def compute_score(solution_str, ground_truth) -> float:
FILE: verl_distillation/recipe/r1/tasks/livecodebench.py
function _temp_run (line 25) | def _temp_run(in_outs, generation, debug, result, metadata_list, timeout):
function check_correctness (line 31) | def check_correctness(in_outs, generation, timeout, debug=True):
function compute_score (line 55) | def compute_score(completion, test_cases):
FILE: verl_distillation/recipe/r1/tasks/math_reward.py
function compute_score (line 23) | def compute_score(model_output: str, ground_truth: str) -> bool:
FILE: verl_distillation/recipe/retool/retool.py
class CustomSandboxFusionTool (line 29) | class CustomSandboxFusionTool(SandboxFusionTool):
method __init__ (line 30) | def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
method execute (line 35) | async def execute(self, instance_id: str, parameters: dict[str, Any], ...
class CustomRLHFDataset (line 64) | class CustomRLHFDataset(RLHFDataset):
method _read_files_and_tokenize (line 67) | def _read_files_and_tokenize(self):
method map_fn (line 84) | def map_fn(self, row: dict, *, data_source: str = None):
method map_fn2 (line 100) | def map_fn2(self, row: dict):
function compute_score (line 107) | def compute_score(data_source, solution_str, ground_truth, extra_info):
FILE: verl_distillation/recipe/retool/retool_sft_preprocess.py
function extract_code_message (line 29) | def extract_code_message(content: str) -> tuple[dict[str, Any], str]:
function extract_answer_message (line 58) | def extract_answer_message(content: str) -> tuple[dict[str, Any], str]:
function extract_interpreter_message (line 74) | def extract_interpreter_message(content: str) -> tuple[dict[str, Any], s...
function process (line 90) | def process(row: dict, *, tools: str):
FILE: verl_distillation/recipe/spin/core_algos.py
class AdaptiveKLController (line 21) | class AdaptiveKLController:
method __init__ (line 27) | def __init__(self, init_kl_coef, target_kl, horizon):
method update (line 32) | def update(self, current_kl, n_steps):
class FixedKLController (line 39) | class FixedKLController:
method __init__ (line 42) | def __init__(self, kl_coef):
method update (line 45) | def update(self, current_kl, n_steps):
function get_kl_controller (line 49) | def get_kl_controller(kl_ctrl):
function compute_onlinedpo_pref (line 59) | def compute_onlinedpo_pref(
function compute_online_dpo_loss (line 131) | def compute_online_dpo_loss(
function get_batch_logps (line 161) | def get_batch_logps(
FILE: verl_distillation/recipe/spin/dp_actor.py
class SPINDataParallelPPOActor (line 33) | class SPINDataParallelPPOActor(DataParallelPPOActor):
method compute_log_prob (line 34) | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
method update_policy_dpo_with_ref (line 92) | def update_policy_dpo_with_ref(self, data: DataProto):
FILE: verl_distillation/recipe/spin/fsdp_workers.py
function create_device_mesh (line 57) | def create_device_mesh(world_size, fsdp_size):
function get_sharding_strategy (line 67) | def get_sharding_strategy(device_mesh):
class SPINRolloutRefWorker (line 79) | class SPINRolloutRefWorker(ActorRolloutRefWorker):
method init_model (line 81) | def init_model(self):
method compute_ref_log_prob (line 169) | def compute_ref_log_prob(self, data: DataProto):
method compute_log_prob (line 194) | def compute_log_prob(self, data: DataProto):
method update_actor_dpo (line 227) | def update_actor_dpo(self, data: DataProto):
class RewardModelWorker (line 289) | class RewardModelWorker(Worker):
method __init__ (line 294) | def __init__(self, config):
method _build_model (line 334) | def _build_model(self, config):
method init_model (line 401) | def init_model(self):
method _forward_micro_batch (line 406) | def _forward_micro_batch(self, micro_batch):
method _expand_to_token_level (line 461) | def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
method _switch_chat_template (line 476) | def _switch_chat_template(self, data: DataProto):
method compute_rm_score (line 542) | def compute_rm_score(self, data: DataProto):
FILE: verl_distillation/recipe/spin/main_spin.py
function main (line 28) | def main(config):
function run_ppo (line 32) | def run_ppo(config) -> None:
class TaskRunner (line 49) | class TaskRunner:
method run (line 50) | def run(self, config):
FILE: verl_distillation/recipe/spin/spin_trainer.py
class ResourcePoolManager (line 49) | class ResourcePoolManager:
method create_resource_pool (line 59) | def create_resource_pool(self):
method get_resource_pool (line 72) | def get_resource_pool(self, role: Role) -> RayResourcePool:
method get_n_gpus (line 76) | def get_n_gpus(self) -> int:
method _check_resource_available (line 80) | def _check_resource_available(self):
function _compute_response_info (line 111) | def _compute_response_info(batch: DataProto) -> dict[str, Any]:
function compute_dpo_data_metrics (line 158) | def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
function apply_kl_penalty (line 247) | def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLCont...
function compute_response_mask (line 277) | def compute_response_mask(data: DataProto):
function compute_onlineDPO_pref (line 284) | def compute_onlineDPO_pref(data: DataProto):
function _timer (line 323) | def _timer(name: str, timing_raw: dict[str, float]):
class RaySPINTrainer (line 329) | class RaySPINTrainer:
method __init__ (line 336) | def __init__(
method _create_dataloader (line 383) | def _create_dataloader(self, train_dataset, val_dataset, collate_fn, t...
method _maybe_log_val_generations (line 463) | def _maybe_log_val_generations(self, inputs, outputs, scores):
method _validate (line 487) | def _validate(self):
method init_workers (line 614) | def init_workers(self):
method _save_checkpoint (line 694) | def _save_checkpoint(self):
method _load_checkpoint (line 749) | def _load_checkpoint(self):
method _balance_batch (line 806) | def _balance_batch(self, batch: DataProto, metrics, logging_prefix="gl...
method fit_dpo (line 823) | def fit_dpo(self): # Renamed for clarity as standard PPO loop
FILE: verl_distillation/recipe/spin/utils.py
function validate_config (line 18) | def validate_config(
FILE: verl_distillation/recipe/sppo/config.py
class SPPOActorConfig (line 21) | class SPPOActorConfig(FSDPActorConfig):
FILE: verl_distillation/recipe/sppo/dp_actor.py
function compute_sppo_loss (line 34) | def compute_sppo_loss(
class DataParallelSPPOActor (line 60) | class DataParallelSPPOActor(DataParallelPPOActor):
method update_policy (line 62) | def update_policy(self, data: DataProto):
FILE: verl_distillation/recipe/sppo/main_sppo.py
function main (line 34) | def main(config):
function run_ppo (line 38) | def run_ppo(config) -> None:
class TaskRunner (line 59) | class TaskRunner:
method run (line 60) | def run(self, config):
FILE: verl_distillation/recipe/sppo/sppo_ray_trainer.py
function softmean (line 50) | def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool ...
function compute_advantage (line 68) | def compute_advantage(data: DataProto, beta=1.0):
class RaySPPOTrainer (line 76) | class RaySPPOTrainer(RayPPOTrainer):
method __init__ (line 83) | def __init__(
method fit (line 127) | def fit(self):
FILE: verl_distillation/recipe/sppo/sppo_worker.py
class SPPOActorRolloutRefWorker (line 33) | class SPPOActorRolloutRefWorker(ActorRolloutRefWorker):
method init_model (line 40) | def init_model(self):
FILE: verl_distillation/recipe/transfer_queue/agent_loop.py
class AgentLoopManager (line 22) | class AgentLoopManager(agent_loop.AgentLoopManager):
method generate_sequences (line 23) | def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:
method _performance_metrics (line 57) | def _performance_metrics(self, metrics: list[list[dict[str, str]]], ou...
method create_transferqueue_client (line 70) | def create_transferqueue_client(self, controller_infos, storage_infos,...
FILE: verl_distillation/recipe/transfer_queue/main_ppo.py
function main (line 42) | def main(config):
function run_ppo (line 52) | def run_ppo(config, task_runner_class=None) -> None:
class TaskRunner (line 111) | class TaskRunner(MainTaskRunner):
method run (line 112) | def run(self, config):
FILE: verl_distillation/recipe/transfer_queue/ray_trainer.py
class ResourcePoolManager (line 98) | class ResourcePoolManager:
method create_resource_pool (line 107) | def create_resource_pool(self):
method get_resource_pool (line 127) | def get_resource_pool(self, role: Role) -> RayResourcePool:
method get_n_gpus (line 131) | def get_n_gpus(self) -> int:
method _check_resource_available (line 135) | def _check_resource_available(self):
function compute_reward_decorated (line 155) | def compute_reward_decorated(data, reward_fn):
function compute_reward_async_decorated (line 160) | def compute_reward_async_decorated(data, reward_fn):
function apply_kl_penalty (line 165) | def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLCont...
function compute_response_mask (line 206) | def compute_response_mask(batch_meta: BatchMeta, data_system_client):
function compute_advantage (line 233) | def compute_advantage(
function compute_data_metrics_decorated (line 308) | def compute_data_metrics_decorated(batch, use_critic: bool = True):
function compute_timing_metrics_decorated (line 313) | def compute_timing_metrics_decorated(batch, timing_raw: dict[str, float]...
function compute_throughout_metrics_decorated (line 318) | def compute_throughout_metrics_decorated(batch, timing_raw: dict[str, fl...
function calculate_debug_metrics_decorated (line 323) | def calculate_debug_metrics_decorated(data):
function compute_val_reward_decorated (line 330) | def compute_val_reward_decorated(reward_fn, data, return_dict):
class RayPPOTrainer (line 334) | class RayPPOTrainer:
method __init__ (line 344) | def __init__(
method _initialize_train_data_system (line 422) | def _initialize_train_data_system(self, global_batch_size, num_n_sampl...
method _initialize_val_data_system (line 471) | def _initialize_val_data_system(self, global_batch_size, num_n_samples...
method _create_dataloader (line 520) | def _create_dataloader(self, train_dataset, val_dataset, collate_fn, t...
method _dump_generations (line 595) | def _dump_generations(self, inputs, outputs, gts, scores, reward_extra...
method _log_rollout_data (line 623) | def _log_rollout_data(
method _maybe_log_val_generations (line 659) | def _maybe_log_val_generations(self, inputs, outputs, scores):
method _get_gen_batch (line 683) | def _get_gen_batch(self, batch: DataProto) -> DataProto:
method _validate (line 700) | def _validate(self):
method init_workers (line 908) | def init_workers(self):
method _save_checkpoint (line 1030) | def _save_checkpoint(self):
method _load_checkpoint (line 1088) | def _load_checkpoint(self):
method _start_profiling (line 1145) | def _start_profiling(self, do_profile: bool) -> None:
method _stop_profiling (line 1156) | def _stop_profiling(self, do_profile: bool) -> None:
method _balance_batch (line 1167) | def _balance_batch(self, batch: BatchMeta, data_system_client, metrics...
method repeat_dict (line 1187) | def repeat_dict(
method dict_to_tensordict (line 1228) | def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]...
method fit (line 1257) | def fit(self):
FILE: verl_distillation/scripts/converter_hf_to_mcore.py
function _init_args (line 49) | def _init_args():
function test_conversion (line 69) | def test_conversion(megatron_model_provider, tfconfig, output_path, model):
function convert_checkpoint_from_transformers_to_megatron (line 118) | def convert_checkpoint_from_transformers_to_megatron(
function safe_copy (line 193) | def safe_copy(
function convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl (line 207) | def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel,...
function convert_checkpoint_from_transformers_to_megatron_dpskv3 (line 314) | def convert_checkpoint_from_transformers_to_megatron_dpskv3(
function noop_context (line 403) | def noop_context() -> Any:
function support_distributed_convert (line 407) | def support_distributed_convert(hf_config: AutoConfig) -> bool:
function convert_hf_to_mcore (line 414) | def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initializati...
FILE: verl_distillation/scripts/diagnose.py
function test_connection (line 50) | def test_connection(name, url, timeout=10):
function check_python (line 70) | def check_python():
function check_pip (line 78) | def check_pip():
function _get_current_git_commit (line 89) | def _get_current_git_commit():
function check_verl (line 101) | def check_verl():
function check_os (line 126) | def check_os():
function check_hardware (line 135) | def check_hardware():
function check_network (line 151) | def check_network(args):
function check_environment (line 170) | def check_environment():
function check_pip_package_versions (line 177) | def check_pip_package_versions():
function check_cuda_versions (line 187) | def check_cuda_versions():
function _get_cpu_memory (line 208) | def _get_cpu_memory():
function _get_gpu_info (line 216) | def _get_gpu_info():
function _get_system_info (line 244) | def _get_system_info():
function check_system_info (line 253) | def check_system_info():
function parse_args (line 263) | def parse_args():
FILE: verl_distillation/scripts/init_random_model.py
function _init_args (line 37) | def _init_args():
function check_output_path (line 46) | def check_output_path(output_path: str):
function check_configs (line 55) | def check_configs(original_config: dict[str, Any], new_config: dict[str,...
function init_random_model (line 72) | def init_random_model(hf_model_path, new_config_path, output_path):
FILE: verl_distillation/scripts/legacy_model_merger.py
class ModelMergerConfig (line 75) | class ModelMergerConfig:
method __post_init__ (line 89) | def __post_init__(self):
class BaseModelMerger (line 97) | class BaseModelMerger(ABC):
method __init__ (line 98) | def __init__(self, config: ModelMergerConfig):
method get_transformers_auto_model_class (line 115) | def get_transformers_auto_model_class(self):
method patch_model_generation_config (line 139) | def patch_model_generation_config(self, model):
method save_lora_adapter (line 155) | def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]):
method save_hf_model_and_tokenizer (line 212) | def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tens...
method upload_to_huggingface (line 241) | def upload_to_huggingface(self):
method merge_and_save (line 249) | def merge_and_save(self):
class FSDPModelMerger (line 253) | class FSDPModelMerger(BaseModelMerger):
method _get_world_size (line 254) | def _get_world_size(self) -> int:
method _load_rank_zero_state_dict (line 264) | def _load_rank_zero_state_dict(self, world_size: int) -> dict:
method _extract_device_mesh_info (line 271) | def _extract_device_mesh_info(self, state_dict: dict, world_size: int)...
method _calculate_shard_configuration (line 291) | def _calculate_shard_configuration(
method _merge_by_placement (line 307) | def _merge_by_placement(self, tensors: list[torch.Tensor], placement: ...
method _load_and_merge_state_dicts (line 318) | def _load_and_merge_state_dicts(
method merge_and_save (line 381) | def merge_and_save(self):
method _test_state_dict (line 404) | def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):
class MegatronModelMerger (line 438) | class MegatronModelMerger(BaseModelMerger):
method __init__ (line 439) | def __init__(self, config: ModelMergerConfig):
method _get_tp_pp_rank_from_sharded_dir (line 482) | def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[...
method _check_megatron_checkpoint_path (line 496) | def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[li...
method _merge_across_tp (line 511) | def _merge_across_tp(
method _load_state_dicts (line 567) | def _load_state_dicts(
method _check_megatron_state_key (line 585) | def _check_megatron_state_key(self, key: str) -> bool:
method _merge_state_dicts (line 609) | def _merge_state_dicts(
method merge_and_save (line 661) | def merge_and_save(self):
method _test_state_dict (line 683) | def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):
method _replace_name (line 704) | def _replace_name(self, megatron_name: str, name_mapping: dict[str, st...
function main (line 716) | def main():
FILE: verl_distillation/scripts/print_cfg.py
function main (line 21) | def main(config):
FILE: verl_distillation/scripts/rollout_viewer.py
function check_textual_version (line 42) | def check_textual_version():
function load_path (line 54) | async def load_path(p: Path, data: dict, mask_strs: str, idx: int, pbar):
function load_dir (line 74) | async def load_dir(path: Path, data: dict[int, dict], pbar, mask_strs: s...
class Highlighter (line 83) | class Highlighter(ReprHighlighter):
function center_word_with_equals_exactly (line 90) | def center_word_with_equals_exactly(word: str, total_length: int, char: ...
function highlight_keyword (line 100) | def highlight_keyword(content: str, keyword: Optional[str]):
class JsonLineViewer (line 129) | class JsonLineViewer(App):
method __init__ (line 175) | def __init__(self, step_num: int, data: dict[int, dict], pbar):
method compose (line 200) | def compose(self) -> ComposeResult:
method on_mount (line 250) | async def on_mount(self) -> None:
method update_result_options (line 268) | def update_result_options(self, offset: int = 0, sort_desc: Optional[b...
method update_content (line 292) | async def update_content(self, search_keyword: Optional[str] = None):
method on_reqid_submitted (line 332) | async def on_reqid_submitted(self, event: Input.Submitted) -> None:
method _update_fields_select (line 373) | def _update_fields_select(self, keys):
method step_changed (line 395) | async def step_changed(self, event):
method sample_changed (line 401) | async def sample_changed(self, event):
method sort_changed (line 407) | async def sort_changed(self, event):
method fields_changed (line 413) | async def fields_changed(self, event):
method fields_all_changed (line 417) | async def fields_all_changed(self, event):
method action_focus_previous (line 424) | def action_focus_previous(self):
method action_focus_next (line 427) | def action_focus_next(self):
method action_next_step (line 430) | async def action_next_step(self) -> None:
method action_next_sample (line 438) | async def action_next_sample(self) -> None:
method action_previous_step (line 446) | async def action_previous_step(self) -> None:
method action_previous_sample (line 454) | async def action_previous_sample(self) -> None:
method action_swith_render (line 462) | async def action_swith_render(self):
method action_toggle_search (line 466) | def action_toggle_search(self) -> None:
method action_cancel_search (line 469) | async def action_cancel_search(self) -> None:
method _clear_search (line 474) | async def _clear_search(self):
method on_search_submitted (line 480) | async def on_search_submitted(self, event: Input.Submitted) -> None:
method action_next_search (line 507) | async def action_next_search(self) -> None:
method action_page_up (line 521) | def action_page_up(self):
method action_page_down (line 524) | def action_page_down(self):
method action_page_home (line 527) | def action_page_home(self):
method action_page_end (line 530) | def action_page_end(self):
function _run (line 534) | async def _run(path: Path, mask_str: str):
function run (line 556) | def run(
FILE: verl_distillation/tests/experimental/agent_loop/agent_utils.py
function init_agent_loop_manager (line 25) | def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | Ra...
FILE: verl_distillation/tests/experimental/agent_loop/test_agent_loop_reward.py
function test_agent_loop_compute_score (line 29) | def test_agent_loop_compute_score():
FILE: verl_distillation/tests/experimental/agent_loop/test_agent_loop_reward_model.py
function test_agent_loop_compute_score_with_model (line 29) | def test_agent_loop_compute_score_with_model():
FILE: verl_distillation/tests/experimental/agent_loop/test_basic_agent_loop.py
function init_config (line 35) | def init_config() -> DictConfig:
function test_single_turn (line 67) | def test_single_turn(init_config):
class WeatherTool (line 131) | class WeatherTool(BaseTool):
method get_current_temperature (line 132) | def get_current_temperature(self, location: str, unit: str = "celsius"):
method get_openai_tool_schema (line 149) | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
method execute (line 153) | async def execute(self, instance_id: str, parameters: dict[str, Any], ...
class WeatherToolWithData (line 161) | class WeatherToolWithData(BaseTool):
method get_openai_tool_schema (line 162) | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
method get_temperature_date (line 166) | def get_temperature_date(self, location: str, date: str, unit: str = "...
method execute (line 185) | async def execute(self, instance_id: str, parameters: dict[str, Any], ...
function test_tool_agent (line 193) | def test_tool_agent(init_config):
function test_tool_agent_with_interaction (line 307) | def test_tool_agent_with_interaction(init_config):
function test_get_trajectory_info (line 432) | async def test_get_trajectory_info():
FILE: verl_distillation/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py
function test_gpt_oss_tool_parser (line 22) | async def test_gpt_oss_tool_parser():
FILE: verl_distillation/tests/experimental/agent_loop/test_multi_modal.py
function init_config (line 33) | def init_config() -> DictConfig:
class ImageGeneratorTool (line 61) | class ImageGeneratorTool(BaseTool):
method generate_image (line 62) | def generate_image(self, description: str, size: str = "256x256"):
method get_openai_tool_schema (line 99) | def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
method execute (line 103) | async def execute(self, instance_id: str, parameters: dict[str, Any], ...
function test_multimodal_tool_agent (line 112) | def test_multimodal_tool_agent(init_config):
function test_multimodal_single_turn_agent (line 249) | def test_multimodal_single_turn_agent(init_config):
function test_multimodal_partial_single_turn_agent (line 381) | def test_multimodal_partial_single_turn_agent(init_config):
FILE: verl_distillation/tests/experimental/agent_loop/test_standalone_rollout.py
function init_config (line 27) | def init_config() -> DictConfig:
function test_standalone_rollout (line 46) | async def test_standalone_rollout(init_config, tp_size):
function test_hybrid_rollout_with_ep (line 98) | def test_hybrid_rollout_with_ep(init_config):
FILE: verl_distillation/tests/experimental/reward/reward_fn.py
function chat_complete (line 41) | async def chat_complete(router_address: str, chat_complete_request: dict):
function compute_score_gsm8k (line 56) | async def compute_score_gsm8k(
FILE: verl_distillation/tests/experimental/reward/test_agent_loop_reward_manager.py
function test_agent_loop_reward_manager (line 27) | def test_agent_loop_reward_manager():
FILE: verl_distillation/tests/experimental/reward/test_reward_model.py
function create_data_samples (line 41) | def create_data_samples() -> DataProto:
function test_reward_model_manager (line 70) | def test_reward_model_manager():
FILE: verl_distillation/tests/interactions/test_gsm8k_interaction.py
class TestGsm8kInteraction (line 24) | class TestGsm8kInteraction:
method setup_method (line 27) | def setup_method(self):
method test_init (line 32) | def test_init(self):
method test_start_interaction_with_instance_id (line 39) | async def test_start_interaction_with_instance_id(self):
method test_start_interaction_without_instance_id (line 53) | async def test_start_interaction_without_instance_id(self):
method test_start_interaction_without_ground_truth (line 65) | async def test_start_interaction_without_ground_truth(self):
method test_generate_response_correct_answer_with_prefix (line 75) | async def test_generate_response_correct_answer_with_prefix(self):
method test_generate_response_correct_answer_without_prefix (line 97) | async def test_generate_response_correct_answer_without_prefix(self):
method test_generate_response_incorrect_answer (line 118) | async def test_generate_response_incorrect_answer(self):
method test_generate_response_multiple_messages (line 139) | async def test_generate_response_multiple_messages(self):
method test_generate_response_no_assistant_message (line 164) | async def test_generate_response_no_assistant_message(self):
method test_calculate_score_direct_call (line 183) | async def test_calculate_score_direct_call(self):
method test_calculate_score_with_kwargs (line 201) | async def test_calculate_score_with_kwargs(self):
method test_finalize_interaction (line 219) | async def test_finalize_interaction(self):
method test_finalize_interaction_with_kwargs (line 234) | async def test_finalize_interaction_with_kwargs(self):
method test_finalize_nonexistent_interaction (line 249) | async def test_finalize_nonexistent_interaction(self):
method test_full_interaction_workflow_correct (line 258) | async def test_full_interaction_workflow_correct(self):
method test_full_interaction_workflow_incorrect (line 281) | async def test_full_interaction_workflow_incorrect(self):
method test_multiple_concurrent_interactions (line 316) | async def test_multiple_concurrent_interactions(self):
method test_edge_case_empty_messages (line 349) | async def test_edge_case_empty_messages(self):
method test_edge_case_message_without_content (line 369) | async def test_edge_case_message_without_content(self):
method test_inheritance_from_base_interaction (line 390) | def test_inheritance_from_base_interaction(self):
method test_name_attribute_initialization (line 408) | def test_name_attribute_initialization(self):
FILE: verl_distillation/tests/interactions/test_interaction_registry.py
class TestInteractionRegistry (line 30) | class TestInteractionRegistry:
method test_get_interaction_class (line 31) | def test_get_interaction_class(self):
method test_initialize_single_interaction_from_config (line 41) | def test_initialize_single_interaction_from_config(self):
method test_initialize_multiple_interactions_from_config (line 69) | def test_initialize_multiple_interactions_from_config(self):
method test_initialize_interaction_without_explicit_name (line 111) | def test_initialize_interaction_without_explicit_name(self):
method test_initialize_empty_config (line 132) | def test_initialize_empty_config(self):
method test_invalid_class_name (line 146) | def test_invalid_class_name(self):
method test_duplicate_interaction_names (line 162) | def test_duplicate_interaction_names(self):
method test_auto_name_generation_edge_cases (line 185) | def test_auto_name_generation_edge_cases(self):
FILE: verl_distillation/tests/models/test_engine.py
function test_actor_engine (line 48) | def test_actor_engine(strategy):
function create_model (line 160) | def create_model():
function test_critic_engine (line 173) | def test_critic_engine(strategy):
function create_actor_model (line 275) | def create_actor_model(tmp_path, config):
function _worker (line 283) | def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: ...
function test_per_tensor_generator (line 353) | def test_per_tensor_generator(world_size, tmp_path, config, strategy):
FILE: verl_distillation/tests/models/test_transformer.py
function test_hf_casual_models (line 41) | def test_hf_casual_models():
function test_hf_value_models (line 111) | def test_hf_value_models():
function test_attn_implementation_override (line 166) | def test_attn_implementation_override():
function test_fsdp_worker_attn_implementation_integration (line 201) | def test_fsdp_worker_attn_implementation_integration():
FILE: verl_distillation/tests/models/test_transformers_ulysses.py
class SequenceParallelConfig (line 44) | class SequenceParallelConfig:
function test_configs (line 50) | def test_configs():
function sync_model_parameters_global (line 87) | def sync_model_parameters_global(layer):
function test_hf_casual_fwd_bwd (line 94) | def test_hf_casual_fwd_bwd(test_config):
function _hf_casual_fwd (line 107) | def _hf_casual_fwd(config, sp_size, dp_size):
function _hf_casual_fwd_bwd (line 186) | def _hf_casual_fwd_bwd(config, sp_size, dp_size):
FILE: verl_distillation/tests/single_controller/base/test_decorator.py
function reset_dispatch_registry (line 29) | def reset_dispatch_registry():
function test_register_new_dispatch_mode (line 38) | def test_register_new_dispatch_mode(reset_dispatch_registry):
function test_update_existing_dispatch_mode (line 60) | def test_update_existing_dispatch_mode(reset_dispatch_registry):
FILE: verl_distillation/tests/single_controller/check_worker_alive/main.py
class TestActor (line 27) | class TestActor(Worker):
method __init__ (line 28) | def __init__(self) -> None:
method foo (line 32) | def foo(self, wait_time):
FILE: verl_distillation/tests/single_controller/detached_worker/client.py
function compute_position_id_with_mask (line 27) | def compute_position_id_with_mask(mask):
FILE: verl_distillation/tests/single_controller/detached_worker/server.py
class Trainer (line 44) | class Trainer(Worker):
method __init__ (line 45) | def __init__(self):
method init_model (line 75) | def init_model(self):
method train_model (line 118) | def train_model(self, data: DataProto) -> DataProto:
FILE: verl_distillation/tests/single_controller/test_auto_padding_on_cpu.py
class Actor (line 30) | class Actor(Worker):
method __init__ (line 31) | def __init__(self) -> None:
method add (line 35) | def add(self, data: DataProto):
function test_auto_padding (line 40) | def test_auto_padding():
FILE: verl_distillation/tests/single_controller/test_colocated_workers.py
class Actor (line 29) | class Actor(Worker):
method __init__ (line 30) | def __init__(self) -> None:
method add (line 34) | def add(self, data: DataProto):
class Critic (line 40) | class Critic(Worker):
method __init__ (line 41) | def __init__(self, config) -> None:
method sub (line 46) | async def sub(self, data: DataProto):
function test_colocated_workers (line 51) | def test_colocated_workers():
FILE: verl_distillation/tests/single_controller/test_colocated_workers_fused.py
class Actor (line 29) | class Actor(Worker):
method __init__ (line 30) | def __init__(self) -> None:
method add (line 34) | def add(self, data: DataProto):
class Critic (line 40) | class Critic(Worker):
method __init__ (line 41) | def __init__(self, config) -> None:
method sub (line 46) | def sub(self, data: DataProto):
function test_colocated_workers_fused (line 51) | def test_colocated_workers_fused():
FILE: verl_distillation/tests/single_controller/test_data_transfer.py
class DummyWorker (line 32) | class DummyWorker(Worker):
method __init__ (line 33) | def __init__(self):
method do_nothing (line 38) | def do_nothing(self, data):
function test_data_transfer (line 46) | def test_data_transfer():
FILE: verl_distillation/tests/single_controller/test_decorator_on_cpu.py
function ray_init_shutdown (line 31) | def ray_init_shutdown():
class DecoratorTestWorker (line 39) | class DecoratorTestWorker(Worker):
method __init__ (line 40) | def __init__(self, initial_value=0):
method dp_compute (line 48) | def dp_compute(self, data: DataProto) -> DataProto:
method async_dp_compute (line 56) | async def async_dp_compute(self, data: DataProto) -> DataProto:
function test_decorator_dp_compute (line 65) | def test_decorator_dp_compute(ray_init_shutdown):
function test_decorator_async_function (line 101) | def test_decorator_async_function(ray_init_shutdown):
FILE: verl_distillation/tests/single_controller/test_device_mesh_register.py
class TestActor (line 25) | class TestActor(Worker):
method __init__ (line 26) | def __init__(self):
method generate_data_proto (line 52) | def generate_data_proto(self, data: DataProto):
method train_data_proto (line 59) | def train_data_proto(self, data: DataProto):
function test_dist_global_info_wg (line 69) | def test_dist_global_info_wg():
FILE: verl_distillation/tests/single_controller/test_driverfunc_to_worker.py
class ModelActor (line 31) | class ModelActor(Worker):
method __init__ (line 32) | def __init__(self):
class HackSelf (line 36) | class HackSelf:
method __init__ (line 37) | def __init__(self):
function get_aux_metrics (line 41) | def get_aux_metrics(self, test_proto):
function test (line 54) | def test():
FILE: verl_distillation/tests/single_controller/test_fused_workers_on_cpu.py
class Actor (line 28) | class Actor(Worker):
method __init__ (line 29) | def __init__(self) -> None:
method add (line 33) | def add(self, x):
class Critic (line 39) | class Critic(Worker):
method __init__ (line 40) | def __init__(self, val) -> None:
method sub (line 45) | def sub(self, x):
class HybridWorker (line 57) | class HybridWorker(FusedBaseClass):
method foo (line 59) | def foo(self, x):
function test_fused_workers (line 63) | def test_fused_workers():
FILE: verl_distillation/tests/single_controller/test_high_level_scheduling_api.py
class TestActor (line 24) | class TestActor(Worker):
method __init__ (line 26) | def __init__(self, cuda_visible_devices=None) -> None:
method get_node_id (line 29) | def get_node_id(self):
function test (line 33) | def test():
FILE: verl_distillation/tests/single_controller/test_nested_worker.py
class TestActor (line 23) | class TestActor(Worker):
method __init__ (line 25) | def __init__(self, x) -> None:
method get (line 30) | def get(self):
class TestHighLevelActor (line 34) | class TestHighLevelActor(Worker):
method __init__ (line 35) | def __init__(self, x=None) -> None:
method get (line 40) | def get(self):
function test_nested_worker (line 44) | def test_nested_worker():
FILE: verl_distillation/tests/single_controller/test_ray_collectives.py
class Actor (line 33) | class Actor(Worker):
method init (line 35) | def init(self):
method send_tensors (line 41) | def send_tensors(self):
class Rollout (line 47) | class Rollout(Worker):
method init (line 49) | def init(self):
method receive_tensors (line 59) | def receive_tensors(self):
method get_tensors (line 67) | def get_tensors(self):
function test_ray_collective_group (line 71) | def test_ray_collective_group():
FILE: verl_distillation/tests/single_controller/test_ray_local_envs_on_cpu.py
class TestActor (line 27) | class TestActor(Worker):
method __init__ (line 28) | def __init__(self) -> None:
method getenv (line 31) | def getenv(self, key):
function test_basics (line 36) | def test_basics():
function test_customized_worker_env (line 53) | def test_customized_worker_env():
FILE: verl_distillation/tests/single_controller/test_ray_utils_on_cpu.py
function init_ray (line 23) | def init_ray():
function test_parallel_put_basic (line 29) | def test_parallel_put_basic(init_ray):
function test_parallel_put_empty (line 37) | def test_parallel_put_empty(init_ray):
function test_parallel_put_workers (line 43) | def test_parallel_put_workers(init_ray):
FILE: verl_distillation/tests/single_controller/test_rvdz.py
class TestWorker (line 19) | class TestWorker:
method __init__ (line 20) | def __init__(self, rank, world_size, group_name):
method init (line 26) | def init(self):
method test (line 31) | def test(self):
function test_rvdz (line 37) | def test_rvdz():
FILE: verl_distillation/tests/single_controller/test_worker_group_basics.py
function two_to_all_dispatch_fn (line 26) | def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
class TestActor (line 42) | class TestActor(Worker):
method __init__ (line 44) | def __init__(self, x) -> None:
method foo (line 48) | def foo(self, y):
method foo_rank_zero (line 52) | def foo_rank_zero(self, x, y):
method foo_one_to_all (line 56) | def foo_one_to_all(self, x, y):
method foo_all_to_all (line 60) | def foo_all_to_all(self, x, y):
method foo_custom (line 64) | def foo_custom(self, x, y):
function remote_call_wg (line 69) | def remote_call_wg(worker_names):
function add_one (line 85) | def add_one(data):
function test_basics (line 92) | def test_basics():
FILE: verl_distillation/tests/single_controller/test_worker_group_torch.py
class TestAllGatherActor (line 29) | class TestAllGatherActor(Worker):
method __init__ (line 30) | def __init__(self, size) -> None:
method init (line 34) | def init(self):
method all_gather (line 39) | def all_gather(self):
class TestAllGatherActorV2 (line 49) | class TestAllGatherActorV2(Worker):
method __init__ (line 50) | def __init__(self, size) -> None:
method all_gather (line 58) | def all_gather(self):
function test_all_gather_torch (line 67) | def test_all_gather_torch():
function test_all_gather_torch_v2 (line 91) | def test_all_gather_torch_v2():
FILE: verl_distillation/tests/special_distributed/test_fsdp_ckpt.py
function create_random_input_ids (line 30) | def create_random_input_ids(batch_size, seq_len, vocab_size):
function test_fsdp_ckpt (line 47) | def test_fsdp_ckpt(strategy="fsdp"):
FILE: verl_distillation/tests/special_distributed/test_mcore_config_converter.py
function check_config_converter_results (line 36) | def check_config_converter_results(tf_config: TransformerConfig | MLATra...
function modify_hf_config (line 67) | def modify_hf_config(name: str, hf_config: PretrainedConfig):
function test_mcore_config_converter (line 74) | def test_mcore_config_converter():
FILE: verl_distillation/tests/special_distributed/test_tensor_dict.py
function test_all_gather_data_proto (line 27) | def test_all_gather_data_proto():
function test_vocab_parallel_entropy (line 58) | def test_vocab_parallel_entropy():
FILE: verl_distillation/tests/special_e2e/check_custom_rwd_fn.py
function check_congratulations_in_file (line 18) | def check_congratulations_in_file(output_file):
FILE: verl_distillation/tests/special_e2e/check_results.py
function extract_reward_from_line (line 20) | def extract_reward_from_line(line):
FILE: verl_distillation/tests/special_e2e/envs/digit_completion/task.py
class DigitCompletion (line 19) | class DigitCompletion:
method __init__ (line 35) | def __init__(self, max_number: int, max_diff: int, max_num_in_response...
method __str__ (line 56) | def __str__(self):
method get_state (line 63) | def get_state(self):
method set_state (line 66) | def set_state(self, state):
method prompt_length (line 71) | def prompt_length(self):
method response_length (line 75) | def response_length(self):
method add (line 80) | def add(self, a, b):
method get_all_prompts (line 83) | def get_all_prompts(self):
method sample_str_prompts (line 93) | def sample_str_prompts(self):
method sample_batch_str_prompts (line 102) | def sample_batch_str_prompts(self, batch_size):
function compute_attention_mask (line 109) | def compute_attention_mask(prompts, pad_token_id):
function compute_position_id_with_mask (line 115) | def compute_position_id_with_mask(mask):
function generate_ground_truth_response (line 119) | def generate_ground_truth_response(prompt: str):
function compute_reward (line 139) | def compute_reward(prompt: str, response: str, sequence_reward=1.0):
FILE: verl_distillation/tests/special_e2e/envs/digit_completion/tokenizer.py
class CharTokenizer (line 29) | class CharTokenizer(PreTrainedTokenizer):
method __init__ (line 30) | def __init__(self, characters: Sequence[str], model_max_length: int, c...
method vocab_size (line 83) | def vocab_size(self) -> int:
method get_vocab (line 86) | def get_vocab(self):
method _tokenize (line 89) | def _tokenize(self, text: str) -> list[str]:
method _convert_token_to_id (line 92) | def _convert_token_to_id(self, token: str) -> int:
method _convert_id_to_token (line 95) | def _convert_id_to_token(self, index: int) -> str:
method convert_tokens_to_string (line 98) | def convert_tokens_to_string(self, tokens):
method build_inputs_with_special_tokens (line 101) | def build_inputs_with_special_tokens(
method get_special_tokens_mask (line 111) | def get_special_tokens_mask(
method get_config (line 129) | def get_config(self) -> dict:
method from_config (line 137) | def from_config(cls, config: dict):
method save_pretrained (line 144) | def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):
method from_pretrained (line 151) | def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):
FILE: verl_distillation/tests/special_e2e/sft/compare_sft_engine_results.py
function get_result (line 21) | def get_result(file):
function compare_results (line 31) | def compare_results(golden_results, other_result):
FILE: verl_distillation/tests/special_e2e/sft/test_sp_loss_match.py
function test_trainer_forward_consistency (line 24) | def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_step...
function create_trainer (line 90) | def create_trainer(config):
function main (line 132) | def main(config):
function hydra_entry (line 147) | def hydra_entry(cfg: DictConfig) -> None:
FILE: verl_distillation/tests/special_sanity/check_api_docs.py
function iter_submodules (line 57) | def iter_submodules(root: ModuleType) -> Iterable[ModuleType]:
function names_missing_doc (line 72) | def names_missing_doc(mod: ModuleType) -> list[str]:
function check_module (line 92) | def check_module(qualname: str) -> list[str]:
function autodiscover_packages (line 106) | def autodiscover_packages() -> list[str]:
function main (line 115) | def main() -> None:
FILE: verl_distillation/tests/special_sanity/check_docs_time_info.py
function is_allowed (line 41) | def is_allowed(path: Path) -> bool:
function main (line 52) | def main():
FILE: verl_distillation/tests/special_sanity/check_docstrings.py
class DocstringChecker (line 25) | class DocstringChecker(ast.NodeVisitor):
method __init__ (line 28) | def __init__(self, filename: str):
method visit_FunctionDef (line 34) | def visit_FunctionDef(self, node: ast.FunctionDef):
method visit_AsyncFunctionDef (line 45) | def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
method visit_ClassDef (line 56) | def visit_ClassDef(self, node: ast.ClassDef):
method _has_docstring (line 67) | def _has_docstring(self, node) -> bool:
function check_file_docstrings (line 72) | def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]:
function main (line 88) | def main():
FILE: verl_distillation/tests/special_sanity/check_license.py
function get_py_files (line 41) | def get_py_files(path_arg: Path) -> Iterable[Path]:
FILE: verl_distillation/tests/special_sanity/check_pr_description.py
class TemplateFileError (line 24) | class TemplateFileError(Exception):
class PRBodyLoadError (line 28) | class PRBodyLoadError(Exception):
class PRDescriptionError (line 32) | class PRDescriptionError(Exception):
function load_template (line 40) | def load_template(path):
function load_pr_body (line 58) | def load_pr_body(event_path):
function check_pr_description (line 67) | def check_pr_description(body, template_lines):
function main (line 84) | def main():
FILE: verl_distillation/tests/special_sanity/test_config_docs.py
function validate_yaml_format (line 19) | def validate_yaml_format(yaml_lines):
function test_trainer_config_doc (line 60) | def test_trainer_config_doc():
FILE: verl_distillation/tests/special_sanity/test_import.py
function test_import (line 16) | def test_import():
function test_single_controller_import (line 22) | def test_single_controller_import():
FILE: verl_distillation/tests/special_sanity/type_coverage_check.py
function get_changed_files (line 27) | def get_changed_files() -> list[Path]:
function get_changed_lines (line 34) | def get_changed_lines(file_path: Path) -> set[int]:
function should_check_type (line 61) | def should_check_type(arg_name: str) -> bool:
function has_type_annotations (line 69) | def has_type_annotations(node: ast.AST, debug: bool = False) -> int:
function check_file (line 85) | def check_file(
function main (line 114) | def main() -> None:
FILE: verl_distillation/tests/special_sanity/validate_imported_docs.py
function _parse_args (line 32) | def _parse_args() -> argparse.Namespace:
function _import_attr (line 57) | def _import_attr(module_name: str, attr_name: str):
function _check_file (line 63) | def _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow...
function main (line 110) | def main() -> None:
FILE: verl_distillation/tests/special_sanity/validate_structure.py
function discover_allowed_modules (line 39) | def discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[s...
function find_violations (line 46) | def find_violations(tests_root: Path, allowed: set[str], allowed_files: ...
function main (line 66) | def main() -> None:
FILE: verl_distillation/tests/special_standalone/test_memory_buffers.py
function test_memory_buffers (line 26) | def test_memory_buffers():
FILE: verl_distillation/tests/test_base_config_on_cpu.py
function base_config_mock (line 21) | def base_config_mock():
function test_getitem_success (line 28) | def test_getitem_success(base_config_mock):
function test_getitem_nonexistent_attribute (line 33) | def test_getitem_nonexistent_attribute(base_config_mock):
function test_getitem_invalid_key_type (line 39) | def test_getitem_invalid_key_type(base_config_mock):
FILE: verl_distillation/tests/test_protocol_on_cpu.py
function test_union_tensor_dict (line 36) | def test_union_tensor_dict():
function test_union_numpy_dict (line 51) | def test_union_numpy_dict():
function test_tensor_dict_constructor (line 141) | def test_tensor_dict_constructor():
function test_tensor_dict_make_iterator (line 155) | def test_tensor_dict_make_iterator():
function test_reorder (line 184) | def test_reorder():
function test_chunk_concat (line 195) | def test_chunk_concat():
function test_concat_metrics_from_multiple_workers (line 219) | def test_concat_metrics_from_multiple_workers():
function test_concat_with_empty_and_non_list_meta_info (line 249) | def test_concat_with_empty_and_non_list_meta_info():
function test_concat_first_worker_missing_metrics (line 272) | def test_concat_first_worker_missing_metrics():
function test_concat_non_list_metrics (line 295) | def test_concat_non_list_metrics():
function test_concat_merge_different_non_metric_keys (line 315) | def test_concat_merge_different_non_metric_keys():
function test_concat_conflicting_non_metric_keys (line 339) | def test_concat_conflicting_non_metric_keys():
function test_pop (line 357) | def test_pop():
function test_repeat (line 370) | def test_repeat():
function test_dataproto_pad_unpad (line 395) | def test_dataproto_pad_unpad():
function test_dataproto_fold_unfold (line 447) | def test_dataproto_fold_unfold():
function test_torch_save_data_proto (line 470) | def test_torch_save_data_proto():
function test_len (line 486) | def test_len():
function test_dataproto_index (line 506) | def test_dataproto_index():
function test_old_vs_new_from_single_dict (line 570) | def test_old_vs_new_from_single_dict():
function test_dataproto_no_batch (line 607) | def test_dataproto_no_batch():
function test_sample_level_repeat (line 617) | def test_sample_level_repeat():
function test_dataproto_unfold_column_chunks (line 642) | def test_dataproto_unfold_column_chunks():
function test_dataproto_chunk_after_index (line 708) | def test_dataproto_chunk_after_index():
function test_to_tensordict (line 754) | def test_to_tensordict():
function test_from_tensordict (line 768) | def test_from_tensordict():
function test_serialize_deserialize_single_tensor (line 782) | def test_serialize_deserialize_single_tensor():
function test_serialize_deserialize_tensordict_regular_tensors (line 799) | def test_serialize_deserialize_tensordict_regular_tensors():
function test_serialize_deserialize_tensordict_nested_tensors (line 828) | def test_serialize_deserialize_tensordict_nested_tensors():
function test_serialize_deserialize_tensordict_mixed_types (line 881) | def test_serialize_deserialize_tensordict_mixed_types():
function test_serialize_deserialize_tensordict_with_device (line 966) | def test_serialize_deserialize_tensordict_with_device():
FILE: verl_distillation/tests/test_protocol_v2_on_cpu.py
function test_union_tensor_dict (line 29) | def test_union_tensor_dict():
function test_tensor_dict_constructor (line 66) | def test_tensor_dict_constructor():
function test_index_select_tensor_dict (line 91) | def test_index_select_tensor_dict():
function test_tensordict_with_images (line 130) | def test_tensordict_with_images():
function test_tensordict_with_packing (line 158) | def test_tensordict_with_packing():
function test_tensordict_eq (line 184) | def test_tensordict_eq():
function test_tensor_dict_make_iterator (line 247) | def test_tensor_dict_make_iterator():
function test_reorder (line 279) | def test_reorder():
function test_chunk_concat (line 292) | def test_chunk_concat():
function test_pop (line 320) | def test_pop():
function test_repeat (line 334) | def test_repeat():
function test_dataproto_pad_unpad (line 359) | def test_dataproto_pad_unpad():
function test_torch_save_data_proto (line 410) | def test_torch_save_data_proto():
function test_len (line 428) | def test_len():
function test_dataproto_index (line 445) | def test_dataproto_index():
function test_select (line 505) | def test_select():
function test_dataproto_no_batch (line 518) | def test_dataproto_no_batch():
function test_sample_level_repeat (line 529) | def test_sample_level_repeat():
function test_dataproto_chunk_after_index (line 555) | def test_dataproto_chunk_after_index():
FILE: verl_distillation/tests/trainer/config/test_algo_config_on_cpu.py
class TestAlgoConfig (line 30) | class TestAlgoConfig(unittest.TestCase):
method setUp (line 33) | def setUp(self):
method test_dataclass_creation_from_dict (line 56) | def test_dataclass_creation_from_dict(self):
method test_dataclass_creation_from_omega_config (line 69) | def test_dataclass_creation_from_omega_config(self):
method test_nested_configs (line 77) | def test_nested_configs(self):
method test_default_values (line 92) | def test_default_values(self):
method test_get_method_backward_compatibility (line 105) | def test_get_method_backward_compatibility(self):
method test_post_init_nested_configs (line 117) | def test_post_init_nested_configs(self):
method test_config_init_from_yaml (line 127) | def test_config_init_from_yaml(self):
class TestAlgoCompute (line 140) | class TestAlgoCompute(unittest.TestCase):
method setUp (line 143) | def setUp(self):
method test_advantage_estimator_with_cfg (line 157) | def test_advantage_estimator_with_cfg(self):
method test_grpo_advantage_estimator_with_cfg (line 182) | def test_grpo_advantage_estimator_with_cfg(self):
FILE: verl_distillation/tests/trainer/config/test_legacy_config_on_cpu.py
class TestConfigComparison (line 35) | class TestConfigComparison(unittest.TestCase):
method _compare_configs_recursively (line 54) | def _compare_configs_recursively(
method test_ppo_trainer_config_matches_legacy (line 110) | def test_ppo_trainer_config_matches_legacy(self):
method test_ppo_megatron_trainer_config_matches_legacy (line 134) | def test_ppo_megatron_trainer_config_matches_legacy(self):
method test_load_component (line 156) | def test_load_component(self):
FILE: verl_distillation/tests/trainer/ppo/test_core_algos_on_cpu.py
function mock_test_fn (line 34) | def mock_test_fn():
class TestRegisterAdvEst (line 38) | class TestRegisterAdvEst(unittest.TestCase):
method setUp (line 39) | def setUp(self):
method tearDown (line 48) | def tearDown(self) -> None:
method test_register_new_function (line 52) | def test_register_new_function(self):
method test_register_with_enum (line 62) | def test_register_with_enum(self):
method test_duplicate_registration_same_function (line 76) | def test_duplicate_registration_same_function(self):
method test_duplicate_registration_different_function (line 83) | def test_duplicate_registration_different_function(self):
method test_decorator_preserves_function (line 96) | def test_decorator_preserves_function(self):
method test_multiple_registrations (line 105) | def test_multiple_registrations(self):
method test_get_adv_estimator_fn_valid_names (line 121) | def test_get_adv_estimator_fn_valid_names(self):
method test_get_adv_estimator_fn_invalid_name (line 131) | def test_get_adv_estimator_fn_invalid_name(self):
method test_get_adv_estimator_fn_case_sensitive (line 137) | def test_get_adv_estimator_fn_case_sensitive(self):
function test_multi_turn_compute_gae_advantage_return (line 143) | def test_multi_turn_compute_gae_advantage_return():
function _make_group_index (line 200) | def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
function _rand_mask (line 214) | def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
function test_rloo_and_vectorized_equivalence (line 230) | def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, ...
function test_grpo_and_vectorized_equivalence (line 270) | def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, ...
FILE: verl_distillation/tests/trainer/ppo/test_metric_utils_on_cpu.py
class TestReduceMetrics (line 37) | class TestReduceMetrics(unittest.TestCase):
method test_reduce_metrics_basic (line 40) | def test_reduce_metrics_basic(self):
method test_reduce_metrics_empty (line 51) | def test_reduce_metrics_empty(self):
method test_reduce_metrics_single_value (line 60) | def test_reduce_metrics_single_value(self):
class TestComputeDataMetrics (line 70) | class TestComputeDataMetrics(unittest.TestCase):
method setUp (line 73) | def setUp(self):
method test_compute_data_metrics_with_critic (line 98) | def test_compute_data_metrics_with_critic(self):
method test_compute_data_metrics_without_critic (line 116) | def test_compute_data_metrics_without_critic(self):
class TestComputeTimingMetrics (line 130) | class TestComputeTimingMetrics(unittest.TestCase):
method setUp (line 133) | def setUp(self):
method test_compute_timing_metrics (line 155) | def test_compute_timing_metrics(self, mock_compute_response_info):
class TestComputeThroughputMetrics (line 181) | class TestComputeThroughputMetrics(unittest.TestCase):
method setUp (line 184) | def setUp(self):
method test_compute_throughout_metrics (line 192) | def test_compute_throughout_metrics(self):
class TestBootstrapMetric (line 213) | class TestBootstrapMetric(unittest.TestCase):
method test_bootstrap_metric_basic (line 216) | def test_bootstrap_metric_basic(self):
method test_bootstrap_metric_empty (line 240) | def test_bootstrap_metric_empty(self):
class TestCalcMajVal (line 246) | class TestCalcMajVal(unittest.TestCase):
method test_calc_maj_val_basic (line 249) | def test_calc_maj_val_basic(self):
method test_calc_maj_val_tie (line 262) | def test_calc_maj_val_tie(self):
class TestProcessValidationMetrics (line 279) | class TestProcessValidationMetrics(unittest.TestCase):
method test_process_validation_metrics_basic (line 282) | def test_process_validation_metrics_basic(self):
method test_process_validation_metrics_with_pred (line 305) | def test_process_validation_metrics_with_pred(self):
FILE: verl_distillation/tests/trainer/ppo/test_rollout_is.py
function test_basic_rollout_is (line 37) | def test_basic_rollout_is():
function test_metrics_completeness (line 156) | def test_metrics_completeness():
function test_mismatch_metrics (line 216) | def test_mismatch_metrics():
function test_mask_mode (line 271) | def test_mask_mode():
FILE: verl_distillation/tests/trainer/ppo/test_rollout_is_integration.py
class TestRolloutISIntegration (line 24) | class TestRolloutISIntegration:
method sample_data (line 28) | def sample_data(self):
method config_with_rollout_is (line 42) | def config_with_rollout_is(self):
method test_policy_loss_with_rollout_is (line 56) | def test_policy_loss_with_rollout_is(self, sample_data, config_with_ro...
method test_rollout_is_weights_computation (line 93) | def test_rollout_is_weights_computation(self, sample_data):
method test_all_aggregation_levels (line 118) | def test_all_aggregation_levels(self, sample_data):
method test_both_bounding_modes (line 134) | def test_both_bounding_modes(self, sample_data):
method test_mismatch_metrics (line 151) | def test_mismatch_metrics(self, sample_data):
method test_veto_mechanism (line 165) | def test_veto_mechanism(self):
method test_metrics_only_mode (line 192) | def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
FILE: verl_distillation/tests/utils/_test_module.py
class TestClass (line 17) | class TestClass:
method __init__ (line 20) | def __init__(self, value=None):
method get_value (line 23) | def get_value(self):
function test_function (line 30) | def test_function():
FILE: verl_distillation/tests/utils/dataset/test_create_rl_sampler_on_cpu.py
class RandomCurriculumSampler (line 29) | class RandomCurriculumSampler(AbstractCurriculumSampler):
method __init__ (line 30) | def __init__(
method __iter__ (line 40) | def __iter__(self):
method __len__ (line 43) | def __len__(self) -> int:
method update (line 46) | def update(self, batch) -> None:
class MockIncorrectSampler (line 50) | class MockIncorrectSampler:
method __init__ (line 53) | def __init__(self, data_source, data_config):
class MockChatDataset (line 57) | class MockChatDataset(Dataset):
method __init__ (line 58) | def __init__(self):
method __getitem__ (line 70) | def __getitem__(self, index):
method __len__ (line 73) | def __len__(self):
function test_create_custom_curriculum_samper (line 77) | def test_create_custom_curriculum_samper():
function test_create_custom_curriculum_samper_wrong_class (line 94) | def test_create_custom_curriculum_samper_wrong_class():
FILE: verl_distillation/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
function test_multiturn_sft_dataset (line 27) | def test_multiturn_sft_dataset():
FILE: verl_distillation/tests/utils/dataset/test_rl_collate_fn_on_cpu.py
function test_rl_collate_fn (line 17) | def test_rl_collate_fn():
FILE: verl_distillation/tests/utils/dataset/test_rl_dataset_on_cpu.py
function get_gsm8k_data (line 21) | def get_gsm8k_data():
function test_rl_dataset (line 29) | def test_rl_dataset():
function test_rl_dataset_with_max_samples (line 69) | def test_rl_dataset_with_max_samples():
function test_image_rl_data (line 88) | def test_image_rl_data():
FILE: verl_distillation/tests/utils/dataset/test_sft_dataset_on_cpu.py
function get_gsm8k_data (line 20) | def get_gsm8k_data():
function test_sft_cot_dataset (line 27) | def test_sft_cot_dataset():
function test_sft_dataset (line 52) | def test_sft_dataset():
function test_sft_dataset_with_max_samples (line 77) | def test_sft_dataset_with_max_samples():
FILE: verl_distillation/tests/utils/debug/test_metrics.py
class TestMetrics (line 22) | class TestMetrics(unittest.TestCase):
method test_calculate_debug_metrics (line 23) | def test_calculate_debug_metrics(self):
FILE: verl_distillation/tests/utils/megatron/test_pipeline_parallel.py
function test_make_batch_generator_no_vpp (line 21) | def test_make_batch_generator_no_vpp():
function test_make_batch_generator_with_vpp (line 28) | def test_make_batch_generator_with_vpp():
function test_make_batch_generator_empty (line 40) | def test_make_batch_generator_empty():
function test_get_dynamic_pipeline_shards (line 63) | def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt):
FILE: verl_distillation/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py
function test_integration_success_correct (line 78) | def test_integration_success_correct():
function test_integration_success_wrong_output (line 89) | def test_integration_success_wrong_output():
function test_integration_compile_error (line 99) | def test_integration_compile_error():
function test_integration_runtime_error (line 108) | def test_integration_runtime_error():
function test_integration_runtime_timeout (line 117) | def test_integration_runtime_timeout():
function test_integration_concurrency_high_load (line 127) | def test_integration_concurrency_high_load():
function test_unit_concurrency_order (line 254) | def test_unit_concurrency_order(mock_call_sandbox_api):
function test_unit_api_timeout_error_concurrent (line 298) | def test_unit_api_timeout_error_concurrent(mock_call_sandbox_api):
function _mock_api_call_for_concurrency_tracking (line 351) | def _mock_api_call_for_concurrency_tracking(
function _process_pool_worker_for_concurrency_test (line 391) | def _process_pool_worker_for_concurrency_test(
function test_multiprocess_global_concurrency_limit_with_semaphore (line 458) | def test_multiprocess_global_concurrency_limit_with_semaphore():
function test_unit_invalid_input_format (line 556) | def test_unit_invalid_input_format():
function test_unit_input_output_mismatch (line 572) | def test_unit_input_output_mismatch():
function test_integration_concurrency_all_timeout (line 581) | def test_integration_concurrency_all_timeout():
function test_fn_name_success_single_case (line 633) | def test_fn_name_success_single_case():
function test_none_and_empty_stdin_passed_correctly (line 672) | def test_none_and_empty_stdin_passed_correctly():
function test_assert_case_success (line 696) | def test_assert_case_success():
FILE: verl_distillation/tests/utils/reward_score/test_sandbox_on_cpu.py
function test_parallelism (line 96) | def test_parallelism():
function test_prime_code (line 118) | def test_prime_code():
function test_prime_code_sandbox_fusion (line 130) | def test_prime_code_sandbox_fusion():
function test_continuous_score_consistency (line 147) | def test_continuous_score_consistency():
function test_check_correctness (line 173) | def test_check_correctness():
function test_prime_math (line 181) | def test_prime_math():
FILE: verl_distillation/tests/utils/test_activation_offload.py
function create_random_input_ids (line 32) | def create_random_input_ids(batch_size, seq_len, vocab_size):
function _fsdp_activation_offloading_test (line 49) | def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, ...
function test_activation_offloading (line 163) | def test_activation_offloading(world_size, strategy, tmp_path):
FILE: verl_distillation/tests/utils/test_config_on_cpu.py
class TestDataclass (line 25) | class TestDataclass(BaseConfig):
class TestTrainConfig (line 31) | class TestTrainConfig(BaseConfig):
class TestConfigOnCPU (line 46) | class TestConfigOnCPU(unittest.TestCase):
method setUp (line 55) | def setUp(self):
method test_omega_conf_to_dataclass (line 58) | def test_omega_conf_to_dataclass(self):
method test_nested_omega_conf_to_dataclass (line 65) | def test_nested_omega_conf_to_dataclass(self):
class TestPrintCfgCommand (line 74) | class TestPrintCfgCommand(unittest.TestCase):
method test_command_with_override (line 77) | def test_command_with_override(self):
FILE: verl_distillation/tests/utils/test_flops_counter.py
class Config (line 24) | class Config:
method __init__ (line 25) | def __init__(self, config_dict):
function test_flops_counter (line 234) | def test_flops_counter(config_type: str):
FILE: verl_distillation/tests/utils/test_fs_on_cpu.py
function test_record_and_check_directory_structure (line 21) | def test_record_and_check_directory_structure(tmp_path):
function test_copy_from_hdfs_with_mocks (line 43) | def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):
function test_always_recopy_flag (line 66) | def test_always_recopy_flag(tmp_path, monkeypatch):
FILE: verl_distillation/tests/utils/test_groupwise.py
function test_as_torch_index_basic_integers (line 27) | def test_as_torch_index_basic_integers():
function test_as_torch_index_near_integer_floats (line 36) | def test_as_torch_index_near_integer_floats():
function test_as_torch_index_factorization_mixed (line 43) | def test_as_torch_index_factorization_mixed():
function test_group_mean_std_simple (line 51) | def test_group_mean_std_simple():
function test_group_mean_std_empty (line 68) | def test_group_mean_std_empty():
FILE: verl_distillation/tests/utils/test_import_utils_on_cpu.py
function test_load_extern_type_class (line 25) | def test_load_extern_type_class():
function test_load_extern_type_function (line 42) | def test_load_extern_type_function():
function test_load_extern_type_constant (line 55) | def test_load_extern_type_constant():
function test_load_extern_type_nonexistent_file (line 64) | def test_load_extern_type_nonexistent_file():
function test_load_extern_type_nonexistent_type (line 70) | def test_load_extern_type_nonexistent_type():
function test_load_extern_type_none_path (line 76) | def test_load_extern_type_none_path():
function test_load_extern_type_invalid_module (line 82) | def test_load_extern_type_invalid_module():
FILE: verl_distillation/tests/utils/test_linear_cross_entropy.py
function run_torch_entropy (line 48) | def run_torch_entropy(
function run_verl_original_entropy (line 64) | def run_verl_original_entropy(
function run_verl_torch_fused_entropy (line 82) | def run_verl_torch_fused_entropy(
class TestLinearCrossEntropy (line 99) | class TestLinearCrossEntropy:
method __init__ (line 100) | def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:
method cleanup (line 104) | def cleanup(self):
method generate_hyper (line 112) | def generate_hyper(self):
method generate_forward_inputs (line 145) | def generate_forward_inputs(self):
method generate_backward_inputs (line 159) | def generate_backward_inputs(self):
method verify_correctness (line 164) | def verify_correctness(self, iterations=5):
method check_storage (line 322) | def check_storage(self, method_name, run_forward):
method check_storage_all (line 344) | def check_storage_all(self):
FILE: verl_distillation/tests/utils/test_mlflow_key_sanitization.py
class TestMlflowLoggingAdapter (line 21) | class TestMlflowLoggingAdapter(unittest.TestCase):
method test_sanitize_key_and_warning (line 22) | def test_sanitize_key_and_warning(self):
FILE: verl_distillation/tests/utils/test_model_on_cpu.py
function test_update_model_config (line 30) | def test_update_model_config(override_kwargs):
FILE: verl_distillation/tests/utils/test_nvtx_profile.py
class TestProfilerConfig (line 24) | class TestProfilerConfig(unittest.TestCase):
method test_config_init (line 25) | def test_config_init(self):
method test_frozen_config (line 52) | def test_frozen_config(self):
class TestNsightSystemsProfiler (line 74) | class TestNsightSystemsProfiler(unittest.TestCase):
method setUp (line 85) | def setUp(self):
method test_initialization (line 90) | def test_initialization(self):
method test_start_stop_profiling (line 94) | def test_start_stop_profiling(self):
method test_annotate_decorator (line 119) | def test_annotate_decorator(self):
FILE: verl_distillation/tests/utils/test_rollout_skip_on_cpu.py
function temp_dir (line 28) | def temp_dir():
function build_generate_fn (line 36) | def build_generate_fn(gen_bs, n):
function mock_rollout_wg (line 56) | def mock_rollout_wg(request):
class TestRolloutSkip (line 74) | class TestRolloutSkip:
method test_initialization (line 75) | def test_initialization(self, capsys):
method test_generate_without_wrap (line 95) | def test_generate_without_wrap(self, mock_rollout_wg):
method test_dump (line 110) | def test_dump(self, mock_rollout_wg, capsys):
method test_generate_with_wrap (line 125) | def test_generate_with_wrap(self, mock_rollout_wg, capsys):
FILE: verl_distillation/tests/utils/test_rollout_trace_on_cpu.py
function reset_rollout_trace_config_singleton (line 25) | def reset_rollout_trace_config_singleton():
function mock_weave_client (line 31) | def mock_weave_client():
class TracedClass (line 46) | class TracedClass:
method my_method (line 50) | async def my_method(self, a, b="default"):
method middle_method (line 56) | async def middle_method(self, a, b="default"):
method my_method_with_exception (line 62) | async def my_method_with_exception(self):
method upper_method (line 65) | async def upper_method(self):
class UntracedClass (line 71) | class UntracedClass:
method my_method (line 73) | async def my_method(self, x):
function test_rollout_trace_on_untraced_class (line 77) | async def test_rollout_trace_on_untraced_class():
function test_rollout_trace_with_tracer (line 83) | async def test_rollout_trace_with_tracer(mock_weave_client):
function test_rollout_trace_with_exception (line 102) | async def test_rollout_trace_with_exception(mock_weave_client):
function test_rollout_trace_with_dummy_backend (line 121) | async def test_rollout_trace_with_dummy_backend(mock_weave_client):
function test_rollout_trace_with_real_weave_backend (line 135) | async def test_rollout_trace_with_real_weave_backend():
function test_rollout_trace_with_real_mlflow_backend (line 156) | async def test_rollout_trace_with_real_mlflow_backend():
FILE: verl_distillation/tests/utils/test_seqlen_balancing.py
function test_seqlen_balancing (line 30) | def test_seqlen_balancing():
function test_dynamic_batch (line 49) | def test_dynamic_batch():
function _worker (line 63) | def _worker(rank, world_size, init_method, max_token_len, use_same_dp, m...
function test_dataproto_split_uneven (line 127) | def test_dataproto_split_uneven():
function test_seqlen_balancing_distributed_params (line 181) | def test_seqlen_balancing_distributed_params(tmp_path):
FILE: verl_distillation/tests/utils/test_special_linear_cross_entropy_tp.py
function run_torch_entropy (line 57) | def run_torch_entropy(
class TorchEntropyTP (line 79) | class TorchEntropyTP(torch.autograd.Function):
method forward (line 86) | def forward(
method backward (line 128) | def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
class TestLinearCrossEntropy_TensorParallel (line 181) | class TestLinearCrossEntropy_TensorParallel:
method __init__ (line 182) | def __init__(self):
method initialize (line 192) | def initialize(self, test_case_idx: int, temperature: float = 1.5):
method shutdown (line 196) | def shutdown(self):
method cleanup (line 199) | def cleanup(self):
method generate_hyper (line 207) | def generate_hyper(self):
method generate_forward_inputs (line 242) | def generate_forward_inputs(self):
method generate_backward_inputs (line 256) | def generate_backward_inputs(self):
method verify_torch_itself (line 261) | def verify_torch_itself(self, iterations: int = 5):
method check_torch_storage (line 331) | def check_torch_storage(self):
method verify_kernel_correctness (line 364) | def verify_kernel_correctness(self, iterations: int = 5):
method check_kernel_storage (line 455) | def check_kernel_storage(self):
FILE: verl_distillation/tests/utils/test_special_mstx_profile.py
class TestNPUProfilerInitialization (line 22) | class TestNPUProfilerInitialization(unittest.TestCase):
method setUp (line 23) | def setUp(self):
method test_init_with_default_config (line 26) | def test_init_with_default_config(self):
method test_init_with_disabled_config (line 32) | def test_init_with_disabled_config(self):
method test_init_with_all_ranks_true (line 39) | def test_init_with_all_ranks_true(self):
method test_init_with_ranks_list (line 45) | def test_init_with_ranks_list(self):
method test_init_with_rank_not_in_ranks (line 51) | def test_init_with_rank_not_in_ranks(self):
class TestNPUProfilerStart (line 58) | class TestNPUProfilerStart(unittest.TestCase):
method setUp (line 59) | def setUp(self):
method test_start_when_enabled_and_this_rank (line 65) | def test_start_when_enabled_and_this_rank(self, mock_get_profiler):
method test_start_when_not_this_rank (line 73) | def test_start_when_not_this_rank(self, mock_get_profiler):
method test_start_discrete_mode_does_not_increase_count (line 81) | def test_start_discrete_mode_does_not_increase_count(self, mock_get_pr...
method test_multiple_start_calls_do_not_increase_count (line 89) | def test_multiple_start_calls_do_not_increase_count(self, mock_get_pro...
class TestNPUProfilerStartStopInteraction (line 97) | class TestNPUProfilerStartStopInteraction(unittest.TestCase):
method setUp (line 98) | def setUp(self):
method test_start_stop_cycle (line 104) | def test_start_stop_cycle(self, mock_get_profiler):
method test_multiple_instances_share_define_count (line 118) | def test_multiple_instances_share_define_count(self, mock_get_profiler):
class TestNPUProfilerAnnotate (line 132) | class TestNPUProfilerAnnotate(unittest.TestCase):
method setUp (line 133) | def setUp(self):
method test_annotate_decorator_applied_correctly (line 138) | def test_annotate_decorator_applied_correctly(self):
method test_annotate_when_profiler_disabled (line 165) | def test_annotate_when_profiler_disabled(self):
method test_annotate_when_this_step_disabled (line 188) | def test_annotate_when_this_step_disabled(self):
method test_annotate_discrete_mode_enabled (line 211) | def test_annotate_discrete_mode_enabled(self):
method test_annotate_with_default_message (line 249) | def test_annotate_with_default_message(self):
FILE: verl_distillation/tests/utils/test_temp_env_on_cpu.py
function clean_env (line 23) | def clean_env():
function test_set_new_env_var (line 42) | def test_set_new_env_var():
function test_restore_existing_env_var (line 56) | def test_restore_existing_env_var():
function test_env_var_restored_on_exception (line 69) | def test_env_var_restored_on_exception():
function test_nested_context_managers (line 85) | def test_nested_context_managers():
function test_multiple_different_vars (line 103) | def test_multiple_different_vars():
function test_empty_string_value (line 118) | def test_empty_string_value():
function test_overwrite_with_empty_string (line 128) | def test_overwrite_with_empty_string():
function test_context_manager_returns_none (line 139) | def test_context_manager_returns_none():
FILE: verl_distillation/tests/utils/test_timeout_decorator_cpu.py
function quick_task (line 30) | def quick_task(x):
function slow_task (line 37) | def slow_task(x):
function task_raises_value_error (line 44) | def task_raises_value_error(): # Now truly not globally decorated
function top_level_decorated_quick_task_signal (line 52) | def top_level_decorated_quick_task_signal():
function top_level_decorated_slow_task_signal (line 62) | def top_level_decorated_slow_task_signal():
function run_target_and_put_in_queue (line 69) | def run_target_and_put_in_queue(target_func, q):
function set_macos_start_method (line 83) | def set_macos_start_method():
function test_quick_task (line 97) | def test_quick_task(): # Renamed from test_multiprocessing_quick_task
function test_slow_task_timeout (line 104) | def test_slow_task_timeout(): # Renamed from test_multiprocessing_slow_...
function test_internal_exception (line 113) | def test_internal_exception(): # Renamed from test_multiprocessing_inte...
function test_signal_quick_task_main_process (line 127) | def test_signal_quick_task_main_process(): # Removed self
function test_signal_slow_task_main_process_timeout (line 139) | def test_signal_slow_task_main_process_timeout(): # Removed self
function test_signal_in_thread_does_not_timeout (line 155) | def test_signal_in_thread_does_not_timeout():
function test_in_thread_timeout (line 200) | def test_in_thread_timeout():
FILE: verl_distillation/tests/utils/test_torch_functional.py
function _worker_mean (line 25) | def _worker_mean(rank: int, world_size: int, rendezvous_file: str):
function test_masked_mean (line 63) | def test_masked_mean(value, mask, gt):
function test_distributed_mean_max_min_std (line 70) | def test_distributed_mean_max_min_std(world_size, tmp_path):
function _worker_mask (line 82) | def _worker_mask(rank: int, world_size: int, rendezvous_file: str):
function test_distributed_masked_mean (line 108) | def test_distributed_masked_mean(world_size, tmp_path):
FILE: verl_distillation/tests/workers/actor/test_special_dp_actor.py
class MockTransformerModel (line 27) | class MockTransformerModel(nn.Module):
method __init__ (line 30) | def __init__(self, vocab_size=1000, hidden_size=64):
method forward (line 40) | def forward(self, input_ids, attention_mask=None, position_ids=None, u...
class TestDataParallelPPOActor (line 54) | class TestDataParallelPPOActor(unittest.TestCase):
method setUpClass (line 58) | def setUpClass(cls):
method setUp (line 74) | def setUp(self):
method tearDownClass (line 98) | def tearDownClass(cls):
method _create_test_data_for_compute_log_prob (line 103) | def _create_test_data_for_compute_log_prob(self):
method _create_test_data_for_update_policy (line 130) | def _create_test_data_for_update_policy(self):
method test_compute_log_prob (line 163) | def test_compute_log_prob(self):
method test_compute_log_prob_without_entropy (line 181) | def test_compute_log_prob_without_entropy(self):
method test_update_policy (line 196) | def test_update_policy(self):
method test_dataparallelppoactor_initialization (line 220) | def test_dataparallelppoactor_initialization(self):
method test_dataparallelppoactor_with_qwen3_model (line 230) | def test_dataparallelppoactor_with_qwen3_model(self):
FILE: verl_distillation/tests/workers/config/test_actor_config_on_cpu.py
class TestActorConfig (line 27) | class TestActorConfig(unittest.TestCase):
method test_config_inheritance (line 30) | def test_config_inheritance(self):
method test_actor_config_from_yaml (line 64) | def test_actor_config_from_yaml(self):
method test_fsdp_actor_config_from_yaml (line 76) | def test_fsdp_actor_config_from_yaml(self):
method test_megatron_actor_config_from_yaml (line 88) | def test_megatron_actor_config_from_yaml(self):
method test_config_get_method (line 100) | def test_config_get_method(self):
method test_config_dict_like_access (line 120) | def test_config_dict_like_access(self):
method test_frozen_fields_modification_raises_exception (line 143) | def test_frozen_fields_modification_raises_exception(self):
method test_actor_config_validation_exceptions (line 166) | def test_actor_config_validation_exceptions(self):
method test_fsdp_actor_config_validation_exceptions (line 208) | def test_fsdp_actor_config_validation_exceptions(self):
method test_actor_config_validate_method_exceptions (line 223) | def test_actor_config_validate_method_exceptions(self):
FILE: verl_distillation/tests/workers/config/test_critic_config_on_cpu.py
class TestCriticConfig (line 33) | class TestCriticConfig:
method config_dir (line 37) | def config_dir(self):
method test_megatron_critic_config_instantiation_from_yaml (line 41) | def test_megatron_critic_config_instantiation_from_yaml(self, config_d...
method test_fsdp_critic_config_instantiation_from_yaml (line 73) | def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):
method test_config_inheritance_hierarchy (line 106) | def test_config_inheritance_hierarchy(self):
method test_config_dict_interface (line 121) | def test_config_dict_interface(self):
method test_frozen_fields_immutability (line 138) | def test_frozen_fields_immutability(self):
method test_batch_size_fields_modifiable (line 161) | def test_batch_size_fields_modifiable(self):
method test_profiler_config_type_validation (line 182) | def test_profiler_config_type_validation(self):
method test_critic_config_validation_logic (line 210) | def test_critic_config_validation_logic(self):
method test_micro_batch_size_divisibility_validation (line 253) | def test_micro_batch_size_divisibility_validation(self):
method test_fsdp_sequence_parallelism_validation (line 278) | def test_fsdp_sequence_parallelism_validation(self):
FILE: verl_distillation/tests/workers/config/test_engine_config_on_cpu.py
class TestMcoreEngineConfig (line 20) | class TestMcoreEngineConfig:
method test_default_values (line 21) | def test_default_values(self):
method test_post_init_validation (line 27) | def test_post_init_validation(self):
method test_mutable_fields (line 36) | def test_mutable_fields(self):
method test_offload_flags (line 43) | def test_offload_flags(self, offload_field):
class TestFSDPEngineConfigCPU (line 48) | class TestFSDPEngineConfigCPU:
method test_default_values (line 49) | def test_default_values(self):
method test_offload_combinations (line 59) | def test_offload_combinations(self, offload_params):
method test_wrap_policy_configuration (line 64) | def test_wrap_policy_configuration(self):
FILE: verl_distillation/tests/workers/config/test_optim_config_on_cpu.py
class TestFSDPOptimizerConfigCPU (line 20) | class TestFSDPOptimizerConfigCPU:
method test_default_configuration (line 21) | def test_default_configuration(self):
method test_valid_lr_scheduler_types (line 28) | def test_valid_lr_scheduler_types(self, lr_scheduler_type):
method test_valid_warmup_style_types (line 33) | def test_valid_warmup_style_types(self, warmup_style):
method test_invalid_lr_scheduler_type (line 37) | def test_invalid_lr_scheduler_type(self):
method test_invalid_warmup_style_type (line 41) | def test_invalid_warmup_style_type(self):
method test_num_cycles_configuration (line 46) | def test_num_cycles_configuration(self, num_cycles):
FILE: verl_distillation/tests/workers/critic/test_special_dp_critic.py
class TestCriticWorker (line 33) | class TestCriticWorker(unittest.TestCase):
method setUpClass (line 35) | def setUpClass(cls):
method tearDownClass (line 52) | def tearDownClass(cls):
method setUp (line 57) | def setUp(self):
method tearDown (line 87) | def tearDown(self):
method _create_test_data_for_compute_values (line 93) | def _create_test_data_for_compute_values(self, batch_size=2, seq_len=1...
method _create_test_data_for_update_critic (line 118) | def _create_test_data_for_update_critic(self, batch_size=2, seq_len=10...
method test_init_model (line 148) | def test_init_model(self):
method test_compute_values (line 158) | def test_compute_values(self):
method test_update_critic (line 176) | def test_update_critic(self):
method test_critic_attn_implementation_override_functionality (line 201) | def test_critic_attn_implementation_override_functionality(self, mock_...
method test_critic_model_config_structure (line 259) | def test_critic_model_config_structure(self):
method test_critic_hydra_config_compatibility (line 289) | def test_critic_hydra_config_compatibility(self):
method test_critic_backward_compatibility (line 309) | def test_critic_backward_compatibility(self):
method test_critic_and_actor_independent_configuration (line 332) | def test_critic_and_actor_independent_configuration(self):
FILE: verl_distillation/tests/workers/reward_manager/test_registry_on_cpu.py
function setup (line 22) | def setup():
function test_get_existing_manager (line 29) | def test_get_existing_manager(setup):
function test_get_nonexistent_manager (line 35) | def test_get_nonexistent_manager(setup):
function test_case_sensitivity (line 42) | def test_case_sensitivity(setup):
function test_empty_registry (line 50) | def test_empty_registry(setup):
function test_register_new_class (line 58) | def test_register_new_class(setup):
function test_register_different_classes_same_name (line 69) | def test_register_different_classes_same_name(setup):
function test_decorator_returns_original_class (line 85) | def test_decorator_returns_original_class(setup):
FILE: verl_distillation/tests/workers/rollout/perf/vllm_async_rollout.py
function init_config (line 48) | def init_config(n_gpus_per_node) -> DictConfig:
function initialize (line 77) | def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGro...
function perf_rollout (line 107) | def perf_rollout(mode, backend, n_gpus_per_node, num_steps):
FILE: verl_distillation/tests/workers/rollout/rollout_sglang/test_http_server_engine.py
function event_loop (line 63) | def event_loop():
function basic_adapter_kwargs (line 71) | def basic_adapter_kwargs():
function router_adapter_kwargs (line 82) | def router_adapter_kwargs():
function non_master_adapter_kwargs (line 95) | def non_master_adapter_kwargs():
function mock_launch_server_process (line 106) | def mock_launch_server_process():
function mock_multiprocessing_process (line 119) | def mock_multiprocessing_process():
function mock_requests_session (line 132) | def mock_requests_session():
function mock_requests_post (line 148) | def mock_requests_post():
function mock_requests_get (line 161) | def mock_requests_get():
function mock_aiohttp_session (line 174) | def mock_aiohttp_session():
function mock_kill_process_tree (line 193) | def mock_kill_process_tree():
function sglang_test_model_path (line 203) | def sglang_test_model_path():
function real_adapter_kwargs (line 215) | def real_adapter_kwargs(sglang_test_model_path):
function mock_server_args_post_init (line 226) | def mock_server_args_post_init():
class TestLaunchServerProcess (line 236) | class TestLaunchServerProcess:
method test_launch_server_process_success (line 239) | def test_launch_server_process_success(
method test_launch_server_process_non_master (line 264) | def test_launch_server_process_non_master(self, mock_multiprocessing_p...
method test_launch_server_process_timeout (line 279) | def test_launch_server_process_timeout(self, mock_multiprocessing_proc...
method test_launch_server_process_died (line 305) | def test_launch_server_process_died(self, real_adapter_kwargs):
class TestHttpServerEngineAdapter (line 322) | class TestHttpServerEngineAdapter:
method test_init_with_router_registration (line 325) | def test_init_with_router_registration(self, mock_launch_server_proces...
method test_init_without_router (line 334) | def test_init_without_router(self, mock_launch_server_process, basic_a...
method test_register_with_router_failure (line 342) | def test_register_with_router_failure(self, mock_launch_server_process...
method test_make_request_success (line 353) | def test_make_request_success(self, mock_launch_server_process, basic_...
method test_make_request_get_method (line 372) | def test_make_request_get_method(self, mock_launch_server_process, bas...
method test_make_request_non_master (line 387) | def test_make_request_non_master(self, mock_launch_server_process):
method test_make_request_retry_logic (line 395) | def test_make_request_retry_logic(self, mock_launch_server_process, ba...
method test_make_request_http_error (line 414) | def test_make_request_http_error(self, mock_launch_server_process, bas...
method test_make_request_max_attempts_exceeded (line 426) | def test_make_request_max_attempts_exceeded(self, mock_launch_server_p...
method test_update_weights_from_tensor_strict (line 439) | def test_update_weights_from_tensor_strict(self, mock_launch_server_pr...
method test_update_weights_from_tensor_empty (line 473) | def test_update_weights_from_tensor_empty(self, mock_launch_server_pro...
method test_update_weights_from_tensor_none (line 502) | def test_update_weights_from_tensor_none(self, mock_launch_server_proc...
method test_generate (line 531) | def test_generate(self, mock_launch_server_process, basic_adapter_kwar...
method test_flush_cache (line 555) | def test_flush_cache(self, mock_launch_server_process, basic_adapter_k...
method test_flush_cache_non_master (line 574) | def test_flush_cache_non_master(self, mock_launch_server_process):
method test_memory_management_methods (line 582) | def test_memory_management_methods(self, mock_launch_server_process, b...
method test_generation_control_methods (line 599) | def test_generation_control_methods(self, mock_launch_server_process, ...
method test_shutdown (line 606) | def test_shutdown(self, mock_launch_server_process, mock_kill_process_...
method test_shutdown_with_errors (line 622) | def test_shutdown_with_errors(self, mock_launch_server_process, mock_k...
method test_empty_and_none_parameters (line 643) | def test_empty_and_none_parameters(self, mock_launch_server_process, b...
method test_large_payload_handling (line 667) | def test_large_payload_handling(self, mock_launch_server_process, basi...
method test_timeout_edge_cases (line 690) | def test_timeout_edge_cases(self, mock_launch_server_process):
method test_extreme_configuration_values (line 702) | def test_extreme_configuration_values(self, mock_launch_server_process):
class TestAsyncHttpServerEngineAdapter (line 721) | class TestAsyncHttpServerEngineAdapter:
method test_init (line 724) | def test_init(self, mock_launch_server_process, basic_adapter_kwargs):
method test_make_async_request_success (line 731) | async def test_make_async_request_success(self, mock_launch_server_pro...
method test_make_async_request_get_method (line 764) | async def test_make_async_request_get_method(self, mock_launch_server_...
method test_make_async_request_non_master (line 793) | async def test_make_async_request_non_master(self, mock_launch_server_...
method test_async_generate (line 802) | async def test_async_generate(self, mock_launch_server_process, basic_...
method test_async_memory_management (line 819) | async def test_async_memory_management(self, mock_launch_server_proces...
class TestErrorRecovery (line 840) | class TestErrorRecovery:
method test_flush_cache_recovery (line 843) | def test_flush_cache_recovery(self, mock_launch_server_process, basic_...
method test_flush_cache_max_attempts (line 860) | def test_flush_cache_max_attempts(self, mock_launch_server_process, ba...
method test_network_partition_recovery (line 872) | def test_network_partition_recovery(self, mock_launch_server_process, ...
class TestResourceManagement (line 889) | class TestResourceManagement:
method test_resource_cleanup_on_exception (line 892) | def test_resource_cleanup_on_exception(
method test_multiple_shutdown_calls (line 909) | def test_multiple_shutdown_calls(self, mock_launch_server_process, bas...
class TestDataTypeHandling (line 919) | class TestDataTypeHandling:
method test_complex_data_structures (line 922) | def test_complex_data_structures(self, mock_launch_server_process, bas...
class TestIntegration (line 956) | class TestIntegration:
method test_error_scenarios (line 959) | def test_error_scenarios(self, mock_launch_server_process, basic_adapt...
FILE: verl_distillation/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py
function main (line 30) | def main():
FILE: verl_distillation/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py
function test_vllm_rollout_with_yarn_position_embeddings (line 32) | def test_vllm_rollout_with_yarn_position_embeddings():
function prepare_input_dataproto (line 104) | def prepare_input_dataproto(tokenizer, config, validate, do_sample=False):
FILE: verl_distillation/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py
function levenshtein (line 29) | def levenshtein(s1, s2):
function are_lists_similar (line 50) | def are_lists_similar(a, b):
function test_vllm_spmd (line 72) | def test_vllm_spmd():
FILE: verl_distillation/tests/workers/rollout/test_hf_rollout.py
function prepare_input_dataproto (line 48) | def prepare_input_dataproto(tokenizer, config, validate):
function prepare_fsdp_model (line 75) | def prepare_fsdp_model(model, world_size):
function test_hf_rollout (line 100) | def test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool =...
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py
function get_search_messages (line 52) | def get_search_messages():
class TestRolloutWithMCPSearchTools (line 120) | class TestRolloutWithMCPSearchTools:
method qwen_tokenizer (line 124) | def qwen_tokenizer(self):
method qwen_model_config (line 131) | def qwen_model_config(self):
method search_data (line 136) | def search_data(self, qwen_tokenizer):
method search_rollout_config (line 150) | def search_rollout_config(self):
method search_data_proto (line 162) | def search_data_proto(self, search_data, qwen_tokenizer):
method mock_rollout (line 196) | def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_mod...
method test_tools_registration (line 291) | def test_tools_registration(self, mock_rollout):
method test_rollout_req_creation (line 300) | def test_rollout_req_creation(self, mock_rollout, search_data_proto):
method test_over_size_case (line 306) | def test_over_size_case(self, mock_rollout, search_data_proto, search_...
method test_tool_call_basic_case (line 351) | def test_tool_call_basic_case(self, mock_execute, mock_rollout, search...
method test_tool_call_batch_case (line 406) | def test_tool_call_batch_case(self, mock_execute, mock_rollout, search...
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py
function _test_add_tool_response_messages_image_delta (line 31) | def _test_add_tool_response_messages_image_delta(processor, image_list, ...
function test_add_tool_response_messages_image_delta (line 157) | def test_add_tool_response_messages_image_delta():
function test_add_tool_response_messages_image_delta_resize_image (line 179) | def test_add_tool_response_messages_image_delta_resize_image():
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_search_tools.py
function get_search_messages (line 56) | def get_search_messages():
class TestRolloutWithSearchTools (line 92) | class TestRolloutWithSearchTools:
method qwen_tokenizer (line 96) | def qwen_tokenizer(self):
method qwen_model_config (line 103) | def qwen_model_config(self):
method search_data (line 108) | def search_data(self, qwen_tokenizer):
method search_rollout_config (line 122) | def search_rollout_config(self):
method search_data_proto (line 134) | def search_data_proto(self, search_data, qwen_tokenizer):
method mock_rollout (line 171) | def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_mod...
method test_tools_registration (line 197) | def test_tools_registration(
method test_rollout_req_creation (line 218) | def test_rollout_req_creation(
method test_over_size_case (line 261) | def test_over_size_case(self, mock_rollout, search_data_proto, search_...
method test_tool_call_basic_case (line 304) | def test_tool_call_basic_case(self, mock_execute, mock_rollout, search...
method test_tool_call_batch_case (line 362) | def test_tool_call_batch_case(self, mock_execute, mock_rollout, search...
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py
function get_sandbox_fusion_messages (line 46) | def get_sandbox_fusion_messages():
function skip_if_valid_sandbox (line 136) | def skip_if_valid_sandbox(url):
class TestRolloutWithTools (line 148) | class TestRolloutWithTools:
method qwen_tokenizer (line 152) | def qwen_tokenizer(self):
method qwen_model_config (line 159) | def qwen_model_config(self):
method sandbox_fusion_data (line 164) | def sandbox_fusion_data(self, qwen_tokenizer):
method sandbox_fusion_rollout_config (line 178) | def sandbox_fusion_rollout_config(self):
method sandbox_data_proto (line 190) | def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer):
method mock_rollout (line 223) | def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, ...
method test_tools_registration (line 245) | def test_tools_registration(self, mock_rollout):
method test_rollout_req_creation (line 254) | def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto):
method test_over_size_case (line 281) | def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbo...
method test_tool_call_basic_case (line 327) | def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, ...
method test_tool_call_batch_case (line 380) | def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, ...
method test_sampling_params_functionality (line 447) | def test_sampling_params_functionality(self, mock_rollout):
class RayMultiProcessTestCase (line 464) | class RayMultiProcessTestCase(MultiProcessTestCase):
method setUp (line 465) | def setUp(self):
method tearDown (line 471) | def tearDown(self):
class TestActor (line 477) | class TestActor:
method __init__ (line 478) | def __init__(self, rank, world_size):
method record_rank (line 484) | def record_rank(self, rank):
method get_rank (line 487) | def get_rank(self):
method ping (line 490) | def ping(self):
method record_execution_time (line 493) | def record_execution_time(self, time):
method get_time (line 496) | def get_time(self, timeout):
method verify_rank (line 510) | def verify_rank(self):
class TestRayGlobalActorCase (line 528) | class TestRayGlobalActorCase(RayMultiProcessTestCase):
method world_size (line 530) | def world_size(self) -> int:
method test_basic_multi_process_init (line 534) | def test_basic_multi_process_init(self):
class TestSingleNodeRateLimiterCase (line 554) | class TestSingleNodeRateLimiterCase(RayMultiProcessTestCase):
method world_size (line 556) | def world_size(self) -> int:
method test_rate_limiter (line 559) | def test_rate_limiter(self):
method test_rotten_execution (line 592) | def test_rotten_execution(self):
class TestMultiNodeRateLimiterCase (line 620) | class TestMultiNodeRateLimiterCase(RayMultiProcessTestCase):
method world_size (line 622) | def world_size(self) -> int:
method test_rate_limiter (line 625) | def test_rate_limiter(self):
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py
function test_async_sglang_rollout_w_interaction (line 40) | def test_async_sglang_rollout_w_interaction():
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_tools.py
function test_async_sglang_rollout_w_tool (line 40) | def test_async_sglang_rollout_w_tool():
FILE: verl_distillation/tests/workers/rollout/test_sglang_async_rollout_w_tools_token_out.py
function test_async_sglang_rollout_w_tool (line 40) | def test_async_sglang_rollout_w_tool():
FILE: verl_distillation/tests/workers/rollout/test_sglang_multi_interaction.py
class MockInteraction (line 39) | class MockInteraction(BaseInteraction):
method __init__ (line 42) | def __init__(self, config):
method start_interaction (line 46) | async def start_interaction(self, instance_id=None, **kwargs):
method generate_response (line 52) | async def generate_response(self, instance_id, messages, **kwargs):
function create_mock_config_with_multi_interactions (line 56) | def create_mock_config_with_multi_interactions():
function setup_distributed (line 109) | def setup_distributed():
class TestSGLangMultiInteraction (line 115) | class TestSGLangMultiInteraction:
method test_initialize_multiple_interactions (line 118) | def test_initialize_multiple_interactions(self):
method test_interaction_selection_by_name (line 173) | def test_interaction_selection_by_name(self):
method test_fallback_to_default_interaction (line 244) | def test_fallback_to_default_interaction(self):
method test_error_on_missing_interaction (line 323) | def test_error_on_missing_interaction(self):
method test_backward_compatibility_no_interaction_config (line 366) | def test_backward_compatibility_no_interaction_config(self):
FILE: verl_distillation/tests/workers/rollout/test_sglang_rollout_sharding_manager.py
function test_get_named_tenso
Copy disabled (too large)
Download .json
Condensed preview — 1960 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (13,773K chars).
[
{
"path": ".gitignore",
"chars": 580,
"preview": "# IDE\n.idea/\n.vscode/\n.claude/\n.gemini/\n*.swp\n*.swo\n*~\n\n# OS\n.DS_Store\nThumbs.db\n\n# Python\n__pycache__/\n*.py[cod]\n*$py.c"
},
{
"path": "README.md",
"chars": 11449,
"preview": "<div align=\"center\">\n <h1>OpenOneRec</h1>\n <p align=\"center\">\n <strong>An Open Foundation Model and Benchmark to Ac"
},
{
"path": "benchmarks/LICENSE",
"chars": 11344,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "benchmarks/README.md",
"chars": 2908,
"preview": "# Benchmark\n\n\n## Quick Start\n\n### Step 1: Install Dependencies\n\n```bash\ncd benchmarks\n\nconda create -n benchmark python="
},
{
"path": "benchmarks/api/README.md",
"chars": 950,
"preview": "# Unified LLM API Wrapper\n\nThis is a unified LLM API wrapper library that provides a clean and elegant interface for cal"
},
{
"path": "benchmarks/api/__init__.py",
"chars": 4917,
"preview": "\"\"\"\nUnified LLM API Wrapper\nSupports convenient calling of Gemini, DeepSeek, and Claude models\n\"\"\"\nimport json\nfrom path"
},
{
"path": "benchmarks/api/base.py",
"chars": 7570,
"preview": "\"\"\"\nBase LLM Client Definition\nProvides unified interface specification with retry mechanism and batch processing\n\"\"\"\nfr"
},
{
"path": "benchmarks/api/claude.py",
"chars": 3048,
"preview": "\"\"\"\nClaude API Client Implementation\nBased on Anthropic official SDK\n\"\"\"\nfrom typing import Optional\nfrom anthropic impo"
},
{
"path": "benchmarks/api/config/llm_config.json",
"chars": 613,
"preview": "{\n \"gemini\": {\n \"project\": \"\",\n \"location\": \"\",\n \"model_name\": \"gemini-2.5-flash-lite\",\n \"credentials_path\""
},
{
"path": "benchmarks/api/deepseek.py",
"chars": 2874,
"preview": "\"\"\"\nDeepSeek API Client Implementation\nCall DeepSeek model through Baidu Qianfan platform\n\"\"\"\nfrom typing import Optiona"
},
{
"path": "benchmarks/api/example.py",
"chars": 8382,
"preview": "\"\"\"\nLLM API Usage Examples\nDemonstrates various calling methods and use cases\n\"\"\"\n\n# ==================================="
},
{
"path": "benchmarks/api/gemini.py",
"chars": 2760,
"preview": "\"\"\"\nGemini API Client Implementation\nBased on Google Vertex AI's Gemini model\n\"\"\"\nimport os\nfrom typing import Optional\n"
},
{
"path": "benchmarks/benchmark/__init__.py",
"chars": 244,
"preview": "from benchmark.benchmark import Benchmark\nfrom benchmark.base_generator import Generator\nfrom benchmark.generation_runn"
},
{
"path": "benchmarks/benchmark/base_generator.py",
"chars": 39097,
"preview": "import os\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, List, Any, Optional\nfrom collections import defau"
},
{
"path": "benchmarks/benchmark/benchmark.py",
"chars": 20241,
"preview": "import os\nimport json\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom pathlib import Path\nfrom datetime "
},
{
"path": "benchmarks/benchmark/checkpoint_utils.py",
"chars": 15103,
"preview": "\"\"\"\nPT format model checkpoint loading tool\n\nSupports loading PyTorch model checkpoints in non-safetensor format\n\"\"\"\n\nim"
},
{
"path": "benchmarks/benchmark/console.py",
"chars": 465,
"preview": "from rich.console import Console\nfrom pyfiglet import Figlet\n\nconsole = Console()\nerr_style = \"bold red\"\nwarning_style ="
},
{
"path": "benchmarks/benchmark/generation_runner.py",
"chars": 12429,
"preview": "\"\"\"\nGeneration Runner\n\nResponsible for:\n1. Loading test data via data loader\n2. Calling Generator to produce model outpu"
},
{
"path": "benchmarks/benchmark/gpu_utils.py",
"chars": 4490,
"preview": "\"\"\"\nGPU hardware detection and FLOPS calculation utilities for MFU computation.\n\"\"\"\n\nfrom typing import Dict, Any, Optio"
},
{
"path": "benchmarks/benchmark/tasks/__init__.py",
"chars": 326,
"preview": "\"\"\"\nTasks definition for Benchmark\n\"\"\"\n\nfrom .tasks import (\n BenchmarkTable,\n check_benchmark_version,\n check_"
},
{
"path": "benchmarks/benchmark/tasks/tasks.py",
"chars": 4234,
"preview": "\"\"\"\nTask table and utility functions for Benchmark\n\"\"\"\n\nfrom typing import List, Optional, Tuple\nfrom benchmark.tasks.v1"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/__init__.py",
"chars": 97,
"preview": "\"\"\"\nv1.0 Version Task Definitions\n\"\"\"\n\nfrom .registry import TaskTable\n\n__all__ = [\"TaskTable\"]\n\n"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/base_evaluator.py",
"chars": 4987,
"preview": "\"\"\"\nBase Evaluator for all task evaluators\n\nProvides common interface for evaluation logic.\n\"\"\"\n\nimport json\nimport os\nf"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/base_loader.py",
"chars": 12555,
"preview": "\"\"\"\nBase Loader for all task data loaders\n\nProvides common functionality for data loading, sampling, and file path resol"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/item_understand/__init__.py",
"chars": 237,
"preview": "\"\"\"\nItem Understand Task Module\n\"\"\"\n\nfrom .config import ITEM_UNDERSTAND_CONFIG\nfrom .evaluator import ItemUnderstandEva"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/item_understand/config.py",
"chars": 1717,
"preview": "\"\"\"\nItem Understand Task Configuration\n\"\"\"\n\n# Item Understand Task Configuration\nITEM_UNDERSTAND_CONFIG = {\n \"name\": "
},
{
"path": "benchmarks/benchmark/tasks/v1_0/item_understand/evaluator.py",
"chars": 8312,
"preview": "\"\"\"\nItem Understand Evaluator\n\nEvaluates model predictions on Item Understand task using WIP (LLM-as-Judge).\n\"\"\"\n\nimport"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/item_understand/utils.py",
"chars": 39883,
"preview": "import json\nimport os\nimport re\nfrom typing import Dict, List, Any, Optional, Tuple\nfrom concurrent.futures import Threa"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/label_pred/__init__.py",
"chars": 353,
"preview": "\"\"\"\nLabel Prediction Task Module\n\nClassification task for predicting user engagement with video content.\nUses logprobs-b"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/label_pred/config.py",
"chars": 1424,
"preview": "\"\"\"\nLabel Prediction Task Configuration\n\nThis is a classification task for predicting user engagement with video content"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/label_pred/evaluator.py",
"chars": 9444,
"preview": "\"\"\"\nLabel Prediction Task Evaluator\n\nEvaluator for label_pred classification task.\nComputes AUC metric from logprobs-bas"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/label_pred/utils.py",
"chars": 7829,
"preview": "\"\"\"\nLabel Prediction Task Utilities\n\nFunctions for label extraction, probability processing, and AUC/wuAUC computation.\n"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/mfu_evaluator.py",
"chars": 6063,
"preview": "\"\"\"\nMFU (Model FLOPs Utilization) Evaluator\n\nComputes MFU metric based on:\n- Model parameters\n- Token statistics\n- GPU h"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/qwen3.jinja2",
"chars": 4168,
"preview": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].con"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/qwen3_soft_switch.jinja2",
"chars": 4529,
"preview": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].con"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/rec_reason/__init__.py",
"chars": 225,
"preview": "\"\"\"\nRecommendation Reason Task Module\n\"\"\"\n\nfrom .config import REC_REASON_CONFIG\nfrom .evaluator import RecoReasonEvalua"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/rec_reason/config.py",
"chars": 1399,
"preview": "\"\"\"\nRecommendation Reason Task Configuration\n\"\"\"\n\n# Recommendation Reason Task Configuration\nREC_REASON_CONFIG = {\n \""
},
{
"path": "benchmarks/benchmark/tasks/v1_0/rec_reason/evaluator.py",
"chars": 7185,
"preview": "\"\"\"\nRecommendation Reason Evaluator\n\nEvaluates model predictions on Recommendation Reason task using LLM-based multi-dim"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/rec_reason/utils.py",
"chars": 13807,
"preview": "\"\"\"\nRecommendation Reason LLM Evaluation Utilities\n\nProvides functions for extracting refined reasoning and multi-dimens"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/recommendation/__init__.py",
"chars": 978,
"preview": "\"\"\"\nRecommendation Task Module\n\nUniversal module for all recommendation tasks including:\n- label_cond: Predict next vide"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/recommendation/config.py",
"chars": 4394,
"preview": "\"\"\"\nRecommendation Task Configurations\n\nThis module contains configurations for all recommendation tasks including:\n- la"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/recommendation/evaluator.py",
"chars": 20862,
"preview": "\"\"\"\nRecommendation Task Evaluator\n\nUniversal evaluator for all recommendation tasks.\nComputes Pass@k and Position1_Pass@"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/recommendation/utils.py",
"chars": 11193,
"preview": "\"\"\"\nRecommendation Task Utilities\n\nFunctions for SID extraction and recommendation metrics computation.\n\"\"\"\n\nfrom typing"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/recommendation/utils_by_pid.py",
"chars": 15068,
"preview": "\"\"\"\nRecommendation Task Utilities (PID-based)\n\nFunctions for PID extraction and recommendation metrics computation using"
},
{
"path": "benchmarks/benchmark/tasks/v1_0/registry.py",
"chars": 5961,
"preview": "\"\"\"\nTask Registry - Unified Task Registration\n\nThis module consolidates:\n- loader_factory.py\n- evaluator_factory.py \n- "
},
{
"path": "benchmarks/eval_script.sh",
"chars": 4373,
"preview": "#!/bin/bash\n\n# Set common variables\nMODEL_PATH=$1\nVERSION=\"${VERSION:-v1.0}\"\nBASE_OUTPUT_DIR=\"${BENCHMARK_BASE_DIR}/resu"
},
{
"path": "benchmarks/pyproject.toml",
"chars": 825,
"preview": "[build-system]\nrequires = [\"setuptools>=45\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"onerec-"
},
{
"path": "benchmarks/requirements.txt",
"chars": 4604,
"preview": "absl-py==2.1.0\naccelerate==1.8.1\naiodns==3.6.1\naiohappyeyeballs==2.6.1\naiohttp==3.11.14\naiohttp-cors==0.8.0\naiosignal==1"
},
{
"path": "benchmarks/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "benchmarks/scripts/eval_dev_results.py",
"chars": 1138,
"preview": "import argparse\n\nfrom benchmark import Benchmark\n\n\ndef get_args():\n parser = argparse.ArgumentParser()\n parser.add"
},
{
"path": "benchmarks/scripts/init_ray.sh",
"chars": 1984,
"preview": "#!/bin/bash\n# Single Node Ray Initialization Script\n# Usage: bash init_ray.sh <HEAD_NODE_IP> <PORT> <RANK>\n# HEAD_NODE"
},
{
"path": "benchmarks/scripts/init_ray_cluster.sh",
"chars": 3981,
"preview": "#!/bin/bash\n# Multi-node Ray Cluster Initialization Script\n# Usage: bash init_ray_cluster.sh [--stop]\n# --stop: Stop R"
},
{
"path": "benchmarks/scripts/ray-vllm/evaluate.py",
"chars": 3579,
"preview": "from transformers import HfArgumentParser\nimport torch\n\nfrom benchmark import Benchmark\nfrom benchmark.console import *\n"
},
{
"path": "benchmarks/scripts/ray-vllm/utils/__init__.py",
"chars": 18,
"preview": "# Ray-vLLM Utils\n\n"
},
{
"path": "benchmarks/scripts/ray-vllm/utils/arguments.py",
"chars": 5502,
"preview": "from dataclasses import dataclass, field\nfrom typing import Optional, List\n\n\n@dataclass\nclass ModelConfig:\n \"\"\"Model "
},
{
"path": "benchmarks/scripts/ray-vllm/utils/generator.py",
"chars": 37947,
"preview": "import os\nimport ray\nimport math\nimport json\nfrom typing import Dict, List, Any, Optional\nfrom vllm import LLM, Sampling"
},
{
"path": "data/README.md",
"chars": 9193,
"preview": "# Dataset Documentation\n\nThis directory contains data processing scripts and dataset format specifications for the OpenO"
},
{
"path": "data/general_text/pretrain.csv",
"chars": 1056,
"preview": "dataname,sample_num,huggingface_repo\r\nNemotron_CC_Math_v1,15440682 ,https://huggingface.co/datasets/nvidia/Nemotron-CC-"
},
{
"path": "data/general_text/sft.csv",
"chars": 870,
"preview": "dataname,sample_num,huggingface_repo\r\nOpenMathReasoning,510163,https://huggingface.co/datasets/nvidia/OpenMathReasoning"
},
{
"path": "data/onerec_data/README.md",
"chars": 4335,
"preview": "# OneRec Data Processing Scripts\n\nThis directory contains data processing scripts for the OneRec project, converting raw"
},
{
"path": "data/onerec_data/pretrain/item_understand.py",
"chars": 3619,
"preview": "\"\"\"\nItem Understand Pretrain Task\nInput: caption parquet (pid, dense_caption) + pid2sid parquet\nOutput: LLM Pretrain for"
},
{
"path": "data/onerec_data/pretrain/user_profile.py",
"chars": 2041,
"preview": "\"\"\"\nUser Profile Pretrain Task\nInput: metadata parquet\nOutput: LLM Pretrain format parquet (segments)\n\nTask: Directly us"
},
{
"path": "data/onerec_data/pretrain/video_rec.py",
"chars": 3598,
"preview": "\"\"\"\nVideo Recommendation Pretrain Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM Pretrain format parquet (se"
},
{
"path": "data/onerec_data/run.sh",
"chars": 4941,
"preview": "#!/bin/bash\n# RecIF Data Processing Script\n# Generate all pretrain and SFT data\n\nset -e\n\n# ============== Task Selection"
},
{
"path": "data/onerec_data/sft/ad_rec.py",
"chars": 6028,
"preview": "\"\"\"\nAd Recommendation Task (Cross-domain)\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parq"
},
{
"path": "data/onerec_data/sft/interactive_rec.py",
"chars": 6071,
"preview": "\"\"\"\nInteractive Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\nT"
},
{
"path": "data/onerec_data/sft/item_understand.py",
"chars": 4320,
"preview": "\"\"\"\nItem Understand Task\nInput: caption parquet (pid, dense_caption) + pid2sid parquet\nOutput: LLM SFT training format p"
},
{
"path": "data/onerec_data/sft/label_cond_rec.py",
"chars": 7298,
"preview": "\"\"\"\nLabel Conditional Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parq"
},
{
"path": "data/onerec_data/sft/label_pred.py",
"chars": 7766,
"preview": "\"\"\"\nLabel Prediction Task (Point-wise Classification)\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training"
},
{
"path": "data/onerec_data/sft/product_rec.py",
"chars": 6870,
"preview": "\"\"\"\nProduct Recommendation Task (Cross-domain)\nInput: metadata parquet + video_pid2sid parquet + product_pid2sid parquet"
},
{
"path": "data/onerec_data/sft/rec_reason.py",
"chars": 3614,
"preview": "\"\"\"\nRecommendation Reasoning Task\nInput: rec_reason parquet (user_profile_with_sid, gsu_caption, target_caption, cot, et"
},
{
"path": "data/onerec_data/sft/video_rec.py",
"chars": 4387,
"preview": "\"\"\"\nVideo Recommendation Task\nInput: metadata parquet + pid2sid parquet\nOutput: LLM SFT training format parquet\n\"\"\"\n\nimp"
},
{
"path": "data/prepare_distillation.sh",
"chars": 1195,
"preview": "#!/bin/bash\n# Data sampling script: Sample specified number of samples from general dataset for on-policy distillation\n\n"
},
{
"path": "data/prepare_pretrain.sh",
"chars": 950,
"preview": "#!/bin/bash\n# Data splitting script: Merge general text and recommendation data, then split by every 1000 samples\n\nset -"
},
{
"path": "data/prepare_rl.sh",
"chars": 2192,
"preview": "#!/bin/bash\n# RL data splitting script: Merge multiple RL task datasets and split into training and test sets\n\nset -e\n\n#"
},
{
"path": "data/prepare_sft.sh",
"chars": 935,
"preview": "#!/bin/bash\n# Data splitting script: Merge general text and recommendation data, then split by every 1000 samples\n\nset -"
},
{
"path": "data/scripts/parquet_unicode_fix.py",
"chars": 11707,
"preview": "#!/usr/bin/env python3\n\"\"\"Parquet Unicode Fix Script\n\nFix unicode Chinese garbled text issues in messages and segments f"
},
{
"path": "data/scripts/sample_data.py",
"chars": 8221,
"preview": "#!/usr/bin/env python3\n\"\"\"Data Sampling Script\n\nSample specified number of samples from one or more paths (directories o"
},
{
"path": "data/scripts/split_data.py",
"chars": 9579,
"preview": "#!/usr/bin/env python3\n\"\"\"Data splitting script\n\nMerge general text data and recommendation data, then split into multip"
},
{
"path": "data/scripts/train_test_split.py",
"chars": 9198,
"preview": "#!/usr/bin/env python3\n\"\"\"Train/Test Split Script\n\nRandomly selects N samples from multiple parquet files as the test se"
},
{
"path": "pretrain/.gitignore",
"chars": 263,
"preview": "# Python\n__pycache__/\n*.pyc\n*.so\n*.egg-info\n*.pylintrc\n\n# Build\nbuild\n\n# IDE\n.vscode/\n.idea/\n*~\n\n# OS\n.DS_Store\n\n# Proje"
},
{
"path": "pretrain/README.md",
"chars": 14251,
"preview": "# OpenOneRec Pretraining Module\n\nThe OpenOneRec pretraining module is based on the Qwen3 architecture, supporting a two-"
},
{
"path": "pretrain/examples/dataset_config/pretrain.json",
"chars": 510,
"preview": "{\n \"name\": \"chat_completion_parquet\",\n \"sources\": \"../output/split_data_pretrain/file_list.json\",\n \"only_assist"
},
{
"path": "pretrain/examples/dataset_config/sft.json",
"chars": 536,
"preview": "{\n \"name\": \"chat_completion_parquet\",\n \"sources\": \"../output/split_data_sft/file_list.json\",\n \"only_assistant_l"
},
{
"path": "pretrain/examples/posttrain_sft.sh",
"chars": 3255,
"preview": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\n# MODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nSTAGE2_OUTPUT_D"
},
{
"path": "pretrain/examples/pretrain_stg1.sh",
"chars": 3097,
"preview": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\nMODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nOUTPUT_DIR=/code/"
},
{
"path": "pretrain/examples/pretrain_stg2.sh",
"chars": 3166,
"preview": "sed 's/=1/=8/g' /etc/mpi/hostfile > /etc/mpi/hostfile_seq\n\n# MODEL_DIR=/code/hf_models/Qwen3-1.7B_itemic\nSTAGE1_OUTPUT_D"
},
{
"path": "pretrain/onerec_llm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "pretrain/onerec_llm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "pretrain/onerec_llm/data/dataloaders.py",
"chars": 1773,
"preview": "\nfrom torchdata.stateful_dataloader import StatefulDataLoader\nfrom onerec_llm.data.qwen3_dataset import Qwen3ChatComplet"
},
{
"path": "pretrain/onerec_llm/data/local_shuffle_buffer.py",
"chars": 6206,
"preview": "\"\"\"\nLocal shuffle buffer for data randomization during iteration.\n\nThis module provides a fixed-size buffer that randomi"
},
{
"path": "pretrain/onerec_llm/data/qwen3_dataset.py",
"chars": 30546,
"preview": "import logging\n\nimport os\nimport json\nimport time\nimport traceback\nimport random\nimport re\n\nimport multiprocessing\nimpor"
},
{
"path": "pretrain/onerec_llm/losses/__init__.py",
"chars": 133,
"preview": "from onerec_llm.losses.ce import CrossEntropyLoss, ChunkedLossComputer\n\n__all__ = [\n \"CrossEntropyLoss\",\n \"ChunkedLoss"
},
{
"path": "pretrain/onerec_llm/losses/ce.py",
"chars": 8096,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom onerec_llm.utils.time_tracker import TimeTracker"
},
{
"path": "pretrain/onerec_llm/models/qwen3/__init__.py",
"chars": 1003,
"preview": "# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License"
},
{
"path": "pretrain/onerec_llm/models/qwen3/configuration_qwen3.py",
"chars": 11232,
"preview": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Lic"
},
{
"path": "pretrain/onerec_llm/models/qwen3/modeling_qwen3.py",
"chars": 52663,
"preview": "# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# This file was automatically generated from"
},
{
"path": "pretrain/onerec_llm/models/qwen3/modular_qwen3.py",
"chars": 8190,
"preview": "# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Lic"
},
{
"path": "pretrain/onerec_llm/training/__init__.py",
"chars": 1461,
"preview": "\"\"\"Training utilities for FSDP-based LLM training.\n\nThis package provides core training functionality including:\n- Distr"
},
{
"path": "pretrain/onerec_llm/training/activations.py",
"chars": 1290,
"preview": "import torch.nn as nn\n\nfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n apply_activation_ch"
},
{
"path": "pretrain/onerec_llm/training/checkpoint.py",
"chars": 17812,
"preview": "from typing import Dict, Any, Union, Optional, Protocol, Callable\nimport re\nimport os\nimport gc\nimport glob\nimport time\n"
},
{
"path": "pretrain/onerec_llm/training/common.py",
"chars": 496,
"preview": "\"\"\"Common training utilities for distributed model training.\"\"\"\n\nfrom typing import Generator\n\nimport contextlib\nimport "
},
{
"path": "pretrain/onerec_llm/training/distributed.py",
"chars": 5459,
"preview": "\"\"\"Distributed training utilities for FSDP model sharding and checkpoint loading.\"\"\"\n\nfrom typing import Any, Dict, Opti"
},
{
"path": "pretrain/onerec_llm/training/gradients.py",
"chars": 6962,
"preview": "\"\"\"Gradient computation and manipulation utilities for training.\n\nThis module provides utilities for gradient processing"
},
{
"path": "pretrain/onerec_llm/training/lr_schedulers.py",
"chars": 4022,
"preview": "\"\"\"Learning rate schedulers for training.\"\"\"\n\nimport math\nfrom functools import partial\nfrom typing import Optional\n\nfro"
},
{
"path": "pretrain/onerec_llm/utils/__init__.py",
"chars": 1502,
"preview": "\"\"\"Utility functions for LLM training.\n\nThis package provides general-purpose utilities including:\n- Common utilities (p"
},
{
"path": "pretrain/onerec_llm/utils/common.py",
"chars": 3909,
"preview": "\"\"\"Common utility functions for the onerec_llm package.\n\nThis module contains core utilities for:\n- Distributed training"
},
{
"path": "pretrain/onerec_llm/utils/data_utils.py",
"chars": 8469,
"preview": "\"\"\"Data loading utilities for parquet files and HDFS.\"\"\"\n\nimport hashlib\nimport os\nimport subprocess\nimport time\nimport "
},
{
"path": "pretrain/onerec_llm/utils/distributed.py",
"chars": 1670,
"preview": "\"\"\"Distributed training base utilities.\n\nThis module provides fundamental distributed training utilities that can be use"
},
{
"path": "pretrain/onerec_llm/utils/ds_utils.py",
"chars": 10650,
"preview": "\"\"\"Debug and formatting utilities for data structures and tensors.\"\"\"\n\nimport math\nimport os\nimport traceback\nfrom datac"
},
{
"path": "pretrain/onerec_llm/utils/mfu_stats.py",
"chars": 15266,
"preview": "\"\"\"Model FLOPs Utilization (MFU) statistics and calculation utilities.\n\nThis module provides functionality to calculate "
},
{
"path": "pretrain/onerec_llm/utils/time_tracker.py",
"chars": 3262,
"preview": "\"\"\"Time tracking utilities for performance profiling.\"\"\"\n\nimport os\nimport time\nfrom typing import Dict, List, Literal, "
},
{
"path": "pretrain/onerec_llm/utils/worker_utils.py",
"chars": 2235,
"preview": "\"\"\"Worker information utilities for PyTorch DataLoader and distributed training.\"\"\"\n\nimport os\nimport torch\nimport torch"
},
{
"path": "pretrain/recipes/train_qwen3.py",
"chars": 51417,
"preview": "\"\"\"Qwen3 Training Script\n\nMulti-node, multi-GPU training script for Qwen3 models using FSDP (Fully Sharded Data Parallel"
},
{
"path": "pretrain/scripts/convert_checkpoint_to_hf.sh",
"chars": 309,
"preview": "#!/bin/bash\n\nset -e\n\nBASE_MODEL_DIR=$1\nMODEL_HOME=$2\nSTEP=$3\nCKPT_DIR=${MODEL_HOME}/step${STEP}/global_step${STEP}\n\nOUTP"
},
{
"path": "pretrain/scripts/expand_qwen3_vocab.sh",
"chars": 410,
"preview": "#!/bin/bash\n\nset -e\n\nHF_MODEL_DIR=/code/onerec_pretrain/hf_models/Qwen3-0.6B\nOUTPUT_MODEL_DIR=/code/onerec_pretrain/hf_m"
},
{
"path": "pretrain/scripts/killall.sh",
"chars": 307,
"preview": "#!/bin/bash\n\nmpirun --allow-run-as-root --hostfile /etc/mpi/hostfile --pernode bash -c \"pkill -9 python3\"\nmpirun --allow"
},
{
"path": "pretrain/scripts/numa_runner.sh",
"chars": 323,
"preview": "#!/bin/bash\n\n# Get local NUMA node count\nnum_numa=$(numactl -H | grep \"node [0-9] cpus\" | wc -l)\nif [ \"$num_numa\" -lt 1 "
},
{
"path": "pretrain/scripts/test_cases_example.json",
"chars": 533,
"preview": "{\n \"test_cases\": [\n {\n \"type\": \"text\",\n \"input\": \"你好,请介绍一下你自己。\",\n \"ground_truth\": \"\"\n },\n {\n "
},
{
"path": "pretrain/scripts/test_hf_model.sh",
"chars": 2184,
"preview": "#!/bin/bash\n\n# HuggingFace Model Testing Script\n# Tests a HuggingFace model with text generation or chat mode\n# \n# Confi"
},
{
"path": "pretrain/set_env.sh",
"chars": 1318,
"preview": "#!/bin/bash\n\n# Check if current shell is bash\nif [ -z \"$BASH_VERSION\" ]; then\n echo \"This script must be run with bas"
},
{
"path": "pretrain/tests/test_qwen3_dataset_file_distribution.py",
"chars": 11979,
"preview": "\"\"\"\nTest file distribution logic for Qwen3ChatCompletionParquetDataset in multi-process, multi-worker scenarios\n\nValidat"
},
{
"path": "pretrain/tools/model_converter/convert_checkpoint_to_hf.py",
"chars": 16247,
"preview": "\"\"\"Checkpoint to HuggingFace Format Converter\n\nThis module provides utilities to convert PyTorch checkpoints (DCP or .pt"
},
{
"path": "pretrain/tools/model_converter/expand_qwen3_vocab.py",
"chars": 12251,
"preview": "\"\"\"Qwen3 Vocabulary Expansion Tool\n\nExpand the standard Qwen3 HuggingFace checkpoint vocabulary to support post-training"
},
{
"path": "pretrain/tools/model_test/test_hf_model.py",
"chars": 13388,
"preview": "#!/usr/bin/env python3\n\"\"\"HuggingFace Model Testing Tool\n\nA unified tool for testing HuggingFace models with both direct"
},
{
"path": "tokenizer/README.md",
"chars": 1865,
"preview": "# Residual K-Means Tokenizer\n\nA residual K-means model for vector quantization. It encodes continuous embeddings into di"
},
{
"path": "tokenizer/infer_res_kmeans.py",
"chars": 3757,
"preview": "import argparse\nimport torch\nimport numpy as np\nimport pandas as pd\nfrom res_kmeans import ResKmeans\n\n\ndef load_embeddin"
},
{
"path": "tokenizer/res_kmeans.py",
"chars": 2465,
"preview": "import torch\nfrom torch import nn\n\nclass ResKmeans(nn.Module):\n\n def __init__(self, n_layers, codebook_size, dim, ext"
},
{
"path": "tokenizer/train_res_kmeans.py",
"chars": 2316,
"preview": "import os\nimport argparse\nimport random\nimport numpy as np\nimport torch\nimport pyarrow.parquet as pq\nfrom tqdm import tq"
},
{
"path": "verl_distillation/LICENSE",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "verl_distillation/README.md",
"chars": 5313,
"preview": "## Overview\n\nThis repository is built on top of the open-source [**verl**](https://github.com/volcengine/verl) (HybridFl"
},
{
"path": "verl_distillation/README_ORIGINAL.md",
"chars": 28570,
"preview": "<div align=\"center\">\n 👋 Hi, everyone! \n verl is a RL training library initiated by <b>ByteDance Seed team</b> and mai"
},
{
"path": "verl_distillation/deploy_env.sh",
"chars": 3443,
"preview": "#!/bin/bash\n# Multi-node Environment Deployment Script\n# Usage: bash deploy_env.sh [--all-nodes]\n\nset -e\n\nSCRIPT_DIR=$(c"
},
{
"path": "verl_distillation/docker/Apptainerfile.rocm",
"chars": 1519,
"preview": "Bootstrap: docker\n\n# Support - Traing: fsdp; Inference: vllm\n# FROM: rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6."
},
{
"path": "verl_distillation/docker/Dockerfile.extention.awsefa",
"chars": 2105,
"preview": "# Base Image support aws EFA\n# Build Image with frameworks based on this\nFROM verlai/verl:app-verl0.6-transformers4.56.1"
},
{
"path": "verl_distillation/docker/Dockerfile.ngc.vllm",
"chars": 1889,
"preview": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:ngc-th2.4.0-cu124-vllm0.6.3-ray2.4-te1.7-v0.0.6\" -f docker"
},
{
"path": "verl_distillation/docker/Dockerfile.ngc.vllm0.8",
"chars": 3336,
"preview": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/f"
},
{
"path": "verl_distillation/docker/Dockerfile.ngc.vllm0.8.sagemaker",
"chars": 1791,
"preview": "# Using a pre-built image from AWS DLC which contains the current version of python (3.10) and supported cuda version (1"
},
{
"path": "verl_distillation/docker/Dockerfile.rocm",
"chars": 10458,
"preview": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_r"
},
{
"path": "verl_distillation/docker/Dockerfile.rocm7",
"chars": 4795,
"preview": "# default base image\nARG REMOTE_VLLM=\"1\"\nARG COMMON_WORKDIR=/app\nARG BASE_IMAGE=rocm/vllm-dev:base\n\nFROM ${BASE_IMAGE} A"
},
{
"path": "verl_distillation/docker/Dockerfile.rocm_verl-0.3.0.post1",
"chars": 1499,
"preview": "# Build the docker in the repo dir:\n# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .\n# docker images "
},
{
"path": "verl_distillation/docker/Dockerfile.rocm_verl-0.4.1",
"chars": 10480,
"preview": "# FROM \"compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_r"
},
{
"path": "verl_distillation/docker/Dockerfile.sglang",
"chars": 2402,
"preview": "# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)\n# https://docs.nvidia.com/deeplearning/frameworks/py"
},
{
"path": "verl_distillation/docker/Dockerfile.vemlp.vllm.te",
"chars": 1780,
"preview": "# docker buildx build --platform linux/x86_64 -t \"verlai/verl:$TAG\" -f docker/$FILE .\n\n# the one in docker.io is an alia"
},
{
"path": "verl_distillation/docker/Dockerfile.vllm.sglang.megatron.deepseek",
"chars": 5389,
"preview": "# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10)\n# https://docs.nvidia.com/deeplearning/f"
},
{
"path": "verl_distillation/docker/README.md",
"chars": 3930,
"preview": "# Dockerfiles of verl\n\nWe provide pre-built Docker images for quick setup. And from this version, we utilize a new image"
},
{
"path": "verl_distillation/docker/ascend/Dockerfile.ascend_8.2.rc1_a2",
"chars": 2631,
"preview": "# 1. Base Image\nFROM swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.2.rc1-910b-ubuntu22.04-py3.11\n\n# 2. Pre-installat"
},
{
"path": "verl_distillation/docker/ascend/Dockerfile.ascend_8.2.rc1_a3",
"chars": 2629,
"preview": "# 1. Base Image\nFROM swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.2.rc1-a3-ubuntu22.04-py3.11\n\n# 2. Pre-installatio"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12",
"chars": 1926,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep",
"chars": 3447,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview",
"chars": 3453,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12",
"chars": 2229,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep",
"chars": 3750,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview",
"chars": 3663,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4\n\n# Defi"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base",
"chars": 5477,
"preview": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai"
},
{
"path": "verl_distillation/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md",
"chars": 1005,
"preview": "# verl image with verl v0.4.x\n\n## Important packages version\n\n```txt\ncuda==12.4\ncudnn==9.8.0\ntorch==2.6.0\nflash_attn=2.7"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.sglang0.4.10.post2.mcore0.13",
"chars": 1680,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# De"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.sglang0.4.9.post6.mcore0.13",
"chars": 1679,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# De"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.vllm.mcore0.13",
"chars": 1575,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4\n\n# De"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.app.vllm.mcore0.15",
"chars": 1627,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM iseekyan/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.7.4-h10"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/Dockerfile.base.torch2.7.1",
"chars": 6184,
"preview": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7-fa2.7.4/README.md",
"chars": 721,
"preview": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.7.4"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12",
"chars": 1724,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# De"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview",
"chars": 1732,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0\n\n# De"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base",
"chars": 6053,
"preview": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai"
},
{
"path": "verl_distillation/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md",
"chars": 621,
"preview": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.6\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0"
},
{
"path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron",
"chars": 1691,
"preview": "# Start from the verl base image\n# Dockerfile.base\nFROM verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8"
},
{
"path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base",
"chars": 4587,
"preview": "# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks\n# Target: verlai"
},
{
"path": "verl_distillation/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md",
"chars": 657,
"preview": "# verl image with verl v0.5\n\n## Important packages version\n\n```txt\ncuda==12.8\ncudnn==9.8.0\ntorch==2.7.1\nflash_attn=2.8.0"
},
{
"path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.app.sglang",
"chars": 179,
"preview": "FROM verlai/verl:base-verl0.6-cu128-cudnn9.8-torch2.8.0-fa2.7.4\n\nRUN pip install --no-cache-dir \"sglang[all]==0.5.2\"\nRUN"
},
{
"path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.base",
"chars": 4068,
"preview": "# Start from the NVIDIA official image (ubuntu-24.04 + cuda-12.8 + python-3.12)\n# https://docs.nvidia.com/deeplearning/f"
},
{
"path": "verl_distillation/docker/verl0.6-cu128-torch2.8.0-fa2.7.4/Dockerfile.vllm011.mcore_gpt-oss",
"chars": 514,
"preview": "FROM nvcr.io/nvidia/nemo:25.07.gpt_oss\n\nRUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt"
},
{
"path": "verl_distillation/docs/Makefile",
"chars": 602,
"preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS =\nSPHI"
},
{
"path": "verl_distillation/docs/README.md",
"chars": 618,
"preview": "# verl documentations\n\n## Build the docs\n\n```bash\n# If you want to view auto-generated API docstring, please make sure v"
},
{
"path": "verl_distillation/docs/README_vllm0.7.md",
"chars": 3033,
"preview": "# Upgrading to vllm >= 0.7\n\nNote: verl+vllm 0.8.3 is now stable. Please see ``docs/README_vllm0.8.md`` for upgrade guide"
},
{
"path": "verl_distillation/docs/README_vllm0.8.md",
"chars": 1533,
"preview": "# Upgrading to vLLM >= 0.8\n\nLast updated: 05/04/2025.\n\n## Installation\n\nNote: This version of verl+vLLM 0.8+ supports **"
},
{
"path": "verl_distillation/docs/_static/custom.css",
"chars": 4891,
"preview": "/* Make the documentation use full screen width */\n.wy-nav-content {\n max-width: none !important;\n width: 100% !im"
},
{
"path": "verl_distillation/docs/_static/js/resizable-sidebar.js",
"chars": 8973,
"preview": "// Resizable sidebar functionality\ndocument.addEventListener('DOMContentLoaded', function() {\n const sidebar = docume"
},
{
"path": "verl_distillation/docs/_static/js/runllm-widget.js",
"chars": 618,
"preview": "document.addEventListener(\"DOMContentLoaded\", function () {\n var script = document.createElement(\"script\");\n scrip"
},
{
"path": "verl_distillation/docs/advance/agent_loop.rst",
"chars": 9955,
"preview": "Agent Loop\n==========\n\nLast updated: 07/17/2025.\n\n.. versionadded:: 0.4.2\n [status: alpha]\n\n.. warning::\n Agent Loop"
},
{
"path": "verl_distillation/docs/advance/attention_implementation.rst",
"chars": 4358,
"preview": ".. _attention-implementation-override:\n\nAttention Implementation Override\n==================================\n\nLast updat"
},
{
"path": "verl_distillation/docs/advance/checkpoint.rst",
"chars": 8695,
"preview": ".. _checkpoint-page:\n\nUsing Checkpoints to Support Fault Tolerance Training\n============================================"
},
{
"path": "verl_distillation/docs/advance/dpo_extension.rst",
"chars": 9704,
"preview": "Extend to other RL(HF) algorithms\n=================================\n\nLast updated: 02/25/2025.\n\nWe already implemented t"
},
{
"path": "verl_distillation/docs/advance/fsdp_extension.rst",
"chars": 4264,
"preview": "\nAdd models with the FSDP backend\n==================================\n\nLast updated: 02/09/2025.\n\nModel\n-----------------"
},
{
"path": "verl_distillation/docs/advance/fully_async.md",
"chars": 30679,
"preview": "# Recipe: Fully Async Policy Trainer\n\n**Author:** `https://github.com/meituan-search`\n\nLast updated: 10/18/2025.\n\nThis d"
},
{
"path": "verl_distillation/docs/advance/megatron_extension.rst",
"chars": 766,
"preview": "Add models with the Megatron-LM backend\n=========================================\n\nLast updated: 04/25/2025.\n\nModel\n----"
},
{
"path": "verl_distillation/docs/advance/one_step_off.md",
"chars": 14340,
"preview": "# Recipe: One Step Off Policy Async Trainer\n\n**Author:** `https://github.com/meituan-search`\n\nLast updated: 07/17/2025."
},
{
"path": "verl_distillation/docs/advance/placement.rst",
"chars": 456,
"preview": "Ray API Design Tutorial\n=======================================\n\nLast updated: 10/30/2024.\n\nWe provide a tutorial for ou"
},
{
"path": "verl_distillation/docs/advance/ppo_lora.rst",
"chars": 4907,
"preview": "RL(HF) algorithms with LoRA Support\n===========================================\n\nLast updated: 06/05/2025.\n\nWe support L"
},
{
"path": "verl_distillation/docs/advance/reward_loop.rst",
"chars": 7211,
"preview": "Reward Loop\n===========\n\n.. _yyding: https://yyding1.github.io\n\nAuthor: `Yuyang Ding <https://yyding1.github.io>`_\n\nLast"
},
{
"path": "verl_distillation/docs/advance/rollout_is.md",
"chars": 30320,
"preview": "# Rollout Importance Sampling\n\n**Author:** [Yingru Li](https://richardli.xyz/)\n\nLast updated: 10/27/2025.\n\nThis document"
},
{
"path": "verl_distillation/docs/advance/rollout_skip.rst",
"chars": 2285,
"preview": "RolloutSkip Function Usage Documentation\n========================================\n\nLast updated: 08/01/2025.\n\nApplicable"
},
{
"path": "verl_distillation/docs/advance/rollout_trace.rst",
"chars": 7763,
"preview": "Trace Function Usage Instructions\n========================================\n\nLast updated: 07/10/2025.\n\nApplicable Scenar"
},
{
"path": "verl_distillation/docs/advance/rope.rst",
"chars": 1168,
"preview": "RoPE Scaling override\n=======================================\n\nLast updated: 05/14/2025.\n\nSome models such as `Qwen/Qwen"
},
{
"path": "verl_distillation/docs/algo/baseline.md",
"chars": 8511,
"preview": "# Algorithm Baselines\n\nLast updated: 06/18/2025.\n\n## Math related datasets\n\n### GSM8k\n\nAssuming GSM8k/math dataset is pr"
},
{
"path": "verl_distillation/docs/algo/collabllm.md",
"chars": 6215,
"preview": "# Recipe: CollabLLM \n\nLast updated: 09/22/2025.\n\n> Open-Source Algorithm Implementation & Expriement Running: [Haiquan C"
},
{
"path": "verl_distillation/docs/algo/dapo.md",
"chars": 10521,
"preview": "# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)\n\nLast updated: 06/19/2025.\n\n> Open-Source Algor"
},
{
"path": "verl_distillation/docs/algo/entropy.md",
"chars": 7904,
"preview": "# Recipe: Entropy Mechanism\n\nLast updated: 06/27/2025.\n\n\n<div align=\"center\">\n\n The Entropy Mechanism of Reinforcement "
},
{
"path": "verl_distillation/docs/algo/gpg.md",
"chars": 1563,
"preview": "# GPG: Group Policy Gradient\n\nLast updated: 07/03/2025.\n\nGroup Policy Gradient (GPG) is a minimalist reinforcement learn"
},
{
"path": "verl_distillation/docs/algo/grpo.md",
"chars": 5539,
"preview": "# Group Relative Policy Optimization (GRPO)\n\nLast updated: 05/31/2025.\n\nIn reinforcement learning, classic algorithms li"
}
]
// ... and 1760 more files (download for full content)
About this extraction
This page contains the full source code of the Kuaishou-OneRec/OpenOneRec GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 1960 files (12.5 MB), approximately 3.4M tokens, and a symbol index with 7561 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.