Repository: GAIR-NLP/ToRL
Branch: main
Commit: 1db091d9cbd3
Files: 221
Total size: 1.6 MB

Directory structure:
gitextract_g7clz9ov/

├── data/
│   └── torl_data/
│       ├── test.parquet
│       ├── test_0524.parquet
│       └── train.parquet
├── readme.md
├── requirements.txt
├── scripts/
│   └── torl_1.5b.sh
└── verl/
    ├── __init__.py
    ├── models/
    │   ├── README.md
    │   ├── __init__.py
    │   ├── llama/
    │   │   ├── __init__.py
    │   │   └── megatron/
    │   │       ├── __init__.py
    │   │       ├── checkpoint_utils/
    │   │       │   ├── __init__.py
    │   │       │   ├── llama_loader.py
    │   │       │   └── llama_saver.py
    │   │       ├── layers/
    │   │       │   ├── __init__.py
    │   │       │   ├── parallel_attention.py
    │   │       │   ├── parallel_decoder.py
    │   │       │   ├── parallel_linear.py
    │   │       │   ├── parallel_mlp.py
    │   │       │   └── parallel_rmsnorm.py
    │   │       └── modeling_llama_megatron.py
    │   ├── qwen2/
    │   │   ├── __init__.py
    │   │   └── megatron/
    │   │       ├── __init__.py
    │   │       ├── checkpoint_utils/
    │   │       │   ├── __init__.py
    │   │       │   ├── qwen2_loader.py
    │   │       │   └── qwen2_saver.py
    │   │       ├── layers/
    │   │       │   ├── __init__.py
    │   │       │   ├── parallel_attention.py
    │   │       │   ├── parallel_decoder.py
    │   │       │   ├── parallel_linear.py
    │   │       │   ├── parallel_mlp.py
    │   │       │   └── parallel_rmsnorm.py
    │   │       └── modeling_qwen2_megatron.py
    │   ├── registry.py
    │   ├── transformers/
    │   │   ├── __init__.py
    │   │   ├── llama.py
    │   │   ├── monkey_patch.py
    │   │   ├── qwen2.py
    │   │   └── qwen2_vl.py
    │   └── weight_loader_registry.py
    ├── protocol.py
    ├── single_controller/
    │   ├── __init__.py
    │   ├── base/
    │   │   ├── __init__.py
    │   │   ├── decorator.py
    │   │   ├── megatron/
    │   │   │   ├── __init__.py
    │   │   │   ├── worker.py
    │   │   │   └── worker_group.py
    │   │   ├── register_center/
    │   │   │   ├── __init__.py
    │   │   │   └── ray.py
    │   │   ├── worker.py
    │   │   └── worker_group.py
    │   └── ray/
    │       ├── __init__.py
    │       ├── base.py
    │       └── megatron.py
    ├── third_party/
    │   ├── __init__.py
    │   └── vllm/
    │       ├── __init__.py
    │       ├── vllm_spmd/
    │       │   ├── __init__.py
    │       │   └── dtensor_weight_loaders.py
    │       ├── vllm_v_0_3_1/
    │       │   ├── __init__.py
    │       │   ├── arg_utils.py
    │       │   ├── config.py
    │       │   ├── llm.py
    │       │   ├── llm_engine_sp.py
    │       │   ├── model_loader.py
    │       │   ├── model_runner.py
    │       │   ├── parallel_state.py
    │       │   ├── tokenizer.py
    │       │   ├── weight_loaders.py
    │       │   └── worker.py
    │       ├── vllm_v_0_4_2/
    │       │   ├── __init__.py
    │       │   ├── arg_utils.py
    │       │   ├── config.py
    │       │   ├── dtensor_weight_loaders.py
    │       │   ├── hf_weight_loader.py
    │       │   ├── llm.py
    │       │   ├── llm_engine_sp.py
    │       │   ├── megatron_weight_loaders.py
    │       │   ├── model_loader.py
    │       │   ├── model_runner.py
    │       │   ├── parallel_state.py
    │       │   ├── spmd_gpu_executor.py
    │       │   ├── tokenizer.py
    │       │   └── worker.py
    │       ├── vllm_v_0_5_4/
    │       │   ├── __init__.py
    │       │   ├── arg_utils.py
    │       │   ├── config.py
    │       │   ├── dtensor_weight_loaders.py
    │       │   ├── hf_weight_loader.py
    │       │   ├── llm.py
    │       │   ├── llm_engine_sp.py
    │       │   ├── megatron_weight_loaders.py
    │       │   ├── model_loader.py
    │       │   ├── model_runner.py
    │       │   ├── parallel_state.py
    │       │   ├── spmd_gpu_executor.py
    │       │   ├── tokenizer.py
    │       │   └── worker.py
    │       └── vllm_v_0_6_3/
    │           ├── __init__.py
    │           ├── arg_utils.py
    │           ├── config.py
    │           ├── dtensor_weight_loaders.py
    │           ├── hf_weight_loader.py
    │           ├── llm.py
    │           ├── llm_engine_sp.py
    │           ├── megatron_weight_loaders.py
    │           ├── model_loader.py
    │           ├── model_runner.py
    │           ├── parallel_state.py
    │           ├── spmd_gpu_executor.py
    │           ├── tokenizer.py
    │           └── worker.py
    ├── trainer/
    │   ├── __init__.py
    │   ├── config/
    │   │   ├── evaluation.yaml
    │   │   ├── generation.yaml
    │   │   ├── ppo_megatron_trainer.yaml
    │   │   ├── ppo_trainer.yaml
    │   │   └── sft_trainer.yaml
    │   ├── fsdp_sft_trainer.py
    │   ├── main_eval.py
    │   ├── main_generation.py
    │   ├── main_ppo.py
    │   ├── ppo/
    │   │   ├── __init__.py
    │   │   ├── core_algos.py
    │   │   └── ray_trainer.py
    │   └── runtime_env.yaml
    ├── utils/
    │   ├── __init__.py
    │   ├── checkpoint/
    │   │   ├── __init__.py
    │   │   ├── checkpoint_manager.py
    │   │   └── fsdp_checkpoint_manager.py
    │   ├── config.py
    │   ├── dataset/
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── rl_dataset.py
    │   │   ├── rm_dataset.py
    │   │   └── sft_dataset.py
    │   ├── debug/
    │   │   ├── __init__.py
    │   │   ├── performance.py
    │   │   └── trajectory_tracker.py
    │   ├── distributed.py
    │   ├── flops_counter.py
    │   ├── fs.py
    │   ├── fsdp_utils.py
    │   ├── hdfs_io.py
    │   ├── import_utils.py
    │   ├── logger/
    │   │   ├── __init__.py
    │   │   └── aggregate_logger.py
    │   ├── logging_utils.py
    │   ├── megatron/
    │   │   ├── __init__.py
    │   │   ├── memory.py
    │   │   ├── optimizer.py
    │   │   ├── pipeline_parallel.py
    │   │   ├── sequence_parallel.py
    │   │   └── tensor_parallel.py
    │   ├── megatron_utils.py
    │   ├── memory_buffer.py
    │   ├── model.py
    │   ├── py_functional.py
    │   ├── ray_utils.py
    │   ├── rendezvous/
    │   │   ├── __init__.py
    │   │   └── ray_backend.py
    │   ├── reward_score/
    │   │   ├── __init__.py
    │   │   ├── eval.py
    │   │   ├── geo3k.py
    │   │   ├── gsm8k.py
    │   │   ├── math.py
    │   │   ├── math_verifier.py
    │   │   ├── prime_code/
    │   │   │   ├── __init__.py
    │   │   │   ├── testing_util.py
    │   │   │   └── utils.py
    │   │   └── prime_math/
    │   │       ├── __init__.py
    │   │       ├── grader.py
    │   │       └── math_normalize.py
    │   ├── seqlen_balancing.py
    │   ├── tokenizer.py
    │   ├── torch_dtypes.py
    │   ├── torch_functional.py
    │   ├── tracking.py
    │   └── ulysses.py
    ├── version/
    │   └── version
    └── workers/
        ├── __init__.py
        ├── actor/
        │   ├── __init__.py
        │   ├── base.py
        │   ├── dp_actor.py
        │   └── megatron_actor.py
        ├── critic/
        │   ├── __init__.py
        │   ├── base.py
        │   ├── dp_critic.py
        │   └── megatron_critic.py
        ├── fsdp_workers.py
        ├── megatron_workers.py
        ├── reward_manager/
        │   ├── __init__.py
        │   ├── naive.py
        │   └── prime.py
        ├── reward_model/
        │   ├── __init__.py
        │   ├── base.py
        │   └── megatron/
        │       ├── __init__.py
        │       └── reward_model.py
        ├── rollout/
        │   ├── __init__.py
        │   ├── base.py
        │   ├── hf_rollout.py
        │   ├── naive/
        │   │   ├── __init__.py
        │   │   └── naive_rollout.py
        │   ├── tokenizer.py
        │   └── vllm_rollout/
        │       ├── __init__.py
        │       ├── fire_vllm_rollout.py
        │       ├── qwen_agent/
        │       │   ├── code/
        │       │   │   ├── code_interpreter.py
        │       │   │   └── utils/
        │       │   │       └── code_utils.py
        │       │   ├── llm/
        │       │   │   └── schema.py
        │       │   ├── log.py
        │       │   ├── settings.py
        │       │   ├── tools/
        │       │   │   ├── base.py
        │       │   │   ├── code_interpreter.py
        │       │   │   └── python_executor.py
        │       │   └── utils/
        │       │       └── utils.py
        │       ├── vllm_rollout.py
        │       └── vllm_rollout_spmd.py
        └── sharding_manager/
            ├── __init__.py
            ├── base.py
            ├── fsdp_ulysses.py
            ├── fsdp_vllm.py
            └── megatron_vllm.py

================================================
FILE CONTENTS
================================================

================================================
FILE: readme.md
================================================
<div align="center">

# ToRL: Scaling Tool-Integrated RL

</div>

<p align="center">
  📄 <a href="https://arxiv.org/pdf/2503.23383" target="_blank">Paper</a> &nbsp; | &nbsp;
  🌐 <a href="https://github.com/GAIR-NLP/ToRL/tree/main/data/torl_data" target="_blank">Dataset</a> &nbsp; | &nbsp;
  📘 <a href="https://huggingface.co/GAIR/ToRL-7B" target="_blank">Model</a>
</p>


<div align="center">
<img src="assets/abstract_aime.png" width="700" alt="torl-abstarct-1">
</div>

> Performance comparison of ToRL versus baseline models (16-step moving average). Both plots show AIME24 accuracy (%) against training steps for 1.5B and 7B models. In both cases, **ToRL (Ours)** significantly outperforms both the baseline without tools and Qwen-2.5-Math-Instruct-TIR, achieving up to **12% (1.5B)** and **14% (7B)** higher accuracy.


<div align="center">
<img src="assets/abstract_case.png" width="700" alt="torl-abstarct-2">
</div>

> **Emergent cognitive behavior** during training. ToRL first cross-validates the tool's output with reasoning results. Upon detecting inconsistencies, it engages in reflection and further verification through tool calls.


## Releases

[2025/03/28] We're releasing the following components:

- 🚀 **Training**: Complete implementation of our training pipeline
- 🔥 **[ToRL Dataset](https://github.com/GAIR-NLP/ToRL/tree/main/data/torl_data)**: Our curated dataset of 28k mathematical questions
- 🤖 **[ToRL Model](https://huggingface.co/GAIR/ToRL)**: Model trained with ToRL.

## Overview

This repository presents **ToRL (Tool-Integrated Reinforcement Learning)**, a framework that challenges traditional approaches to tool integration in language models by enabling LLMs to autonomously discover and refine tool usage strategies through reinforcement learning. Unlike prior methods constrained by supervised fine-tuning or predefined tool patterns, ToRL demonstrates that **exploration-driven learning** with computational tools can unlock emergent cognitive behaviors and achieve state-of-the-art performance on complex reasoning tasks. Notably, our approach operates directly from base models without imitation learning, achieving **43.3% accuracy on AIME2024** with a 7B model—matching the performance of larger 32B models trained with RL.  


### Key Findings  
- **Autonomous Tool Integration**: Models learn **when and how** to invoke tools (e.g., code interpreters) through RL-driven exploration, eliminating dependency on human-curated tool usage patterns.  
- **Emergent Cognitive Abilities**:  
  - Self-correction by cross-validating code execution results with reasoning steps  
  - Adaptive strategy selection between tool-based and pure-reasoning approaches  
  - Self-regulation of ineffective tool calls without explicit supervision  

### ToRL Performance
1.5B Model Performance across challenging mathematical benchmarks:

| Model                            | SFT/RL | Tool | AIME24 | AIME25 | MATH500 | Olympiad | AMC23 | Avg   |
|----------------------------------|--------|------|--------|--------|---------|----------|-------|-------|
| Qwen2.5-Math-1.5B-Instruct       | RL     | ❌    | 10.0   | 10.0   | 66.0    | 31.0     | 62.5  | 35.9  |
| Qwen2.5-Math-1.5B-Instruct-TIR   | RL     |  ✅    | 13.3   | 13.3   | 73.8    | 41.3     | 55.0  | 41.3  |
| **ToRL-1.5B(Ours)**              | RL     |  ✅    | **26.7 (+13.3)** | **26.7 (+13.3)** | **77.8 (+3.0)** | **44.0 (+2.7)** | **67.5 (+5.0)** | **48.5 (+7.2)** |


7B Model Performance across challenging mathematical benchmarks:

| Model                            | SFT/RL | Tool | AIME24 | AIME25 | MATH500 | Olympiad | AMC23 | Avg   |
|----------------------------------|--------|------|--------|--------|---------|----------|-------|-------|
| Qwen2.5-Math-7B-Instruct         | RL     | ❌    | 10.0   | 16.7   | 74.8    | 32.4     | 65.0  | 39.8  |
| Qwen2.5-Math-7B-Instruct-TIR     | RL     |  ✅    | 26.7   | 16.7   | 78.8    | 45.0     | 70.0  | 47.4  |
| SimpleRL-Zero                     | RL     | ❌    | 33.3   | 6.7    | 77.2    | 37.6     | 62.5  | 43.5  |
| rStar-Math-7B                    | SFT    | ❌    | 26.7   | -      | 78.4    | 47.1     | 47.5  | -     |
| Eurus-2-7B-PRIME                  | RL     | ❌    | 26.7   | 13.3   | 79.2    | 42.1     | 57.4  | 43.1  |
| **ToRL-7B(Ours)**                 | RL     |  ✅    | **43.3 (+10.0)** | **30.0 (+13.3)** | **82.2 (+3.0)** | **49.9 (+2.8)** | **75.0 (+5.0)** | **62.1 (+14.7)** |

### Cognitive Behavior via RL Scaling

In the left figure, the code initially generated by the model encounters an execution error; the model then corrects it, and the revised code executes successfully.

In the right figure, the model first derives an incorrect result through natural-language reasoning, then discovers the error during code verification and corrects it.

<div style="display: flex; justify-content: center; gap: 20px;">
    <img src="assets/case1.png" width="350" alt="cases">
    <img src="assets/case2.png" width="350" alt="cases2">
</div>




## Quick Start

### Preparing the Sandbox Environment  
Install and launch SandboxFusion following the instructions at [https://github.com/bytedance/SandboxFusion](https://github.com/bytedance/SandboxFusion).

```
# The following command can be used to install the sandbox environment 
# to avoid dependency conflicts.  
# The sandbox environment must be named "sandbox-runtime".  
conda create -n sandbox-runtime python==3.11
conda activate sandbox-runtime
pip install -r runtime/python/requirement.txt

pip install poetry  
poetry install  
mkdir -p docs/build  
make run-online  
```  

Replace the `sandbox_url` on line `109` of `verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py` with the URL of your sandbox.
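
For reference, here is a minimal sketch of submitting generated code to a running SandboxFusion instance over its `run_code` HTTP endpoint. The payload shape follows the SandboxFusion README; `SANDBOX_URL` and the `run_in_sandbox` helper are illustrative placeholders rather than verl's actual client code:

```python
import requests

# Illustrative placeholder; point this at your SandboxFusion deployment.
SANDBOX_URL = "http://localhost:8080/run_code"

def run_in_sandbox(code: str, timeout: float = 30.0) -> dict:
    """Execute a Python snippet in SandboxFusion and return the JSON result."""
    resp = requests.post(
        SANDBOX_URL,
        json={"code": code, "language": "python"},
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()

print(run_in_sandbox("print(1 + 1)"))
```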

### Environment setup
```
pip install -r requirements.txt
pip install wandb jsonlines math-verify hydra-core==1.4.0.dev1 sortedcontainers qwen-agent[code_interpreter] qwen-agent[python_executor]
```

### Training

Execute `bash scripts/torl_1.5b.sh` to run ToRL.


## Acknowledgements

Our work builds upon the insightful technical reports of the [DeepSeek R1](https://github.com/deepseek-ai/DeepSeek-R1) and [Kimi-k1.5](https://github.com/MoonshotAI/Kimi-k1.5) teams. We extend our appreciation to the [Qwen-Math](https://github.com/QwenLM/Qwen2.5-Math) team for their open-source model; to the creators of [VeRL](https://github.com/volcengine/verl) and [vLLM](https://github.com/vllm-project/vllm) for providing the reinforcement learning framework and inference infrastructure, respectively, that enabled this research; and to the [Qwen-Agent](https://github.com/QwenLM/Qwen-Agent) and [Sandbox Fusion](https://github.com/bytedance/SandboxFusion) teams, which provided the necessary tools for our work.

## Citation

If you find this work useful, please cite our paper:

```bibtex
@misc{li2025torlscalingtoolintegratedrl,
      title={ToRL: Scaling Tool-Integrated RL}, 
      author={Xuefeng Li and Haoyang Zou and Pengfei Liu},
      year={2025},
      eprint={2503.23383},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2503.23383}, 
}
```


================================================
FILE: requirements.txt
================================================
accelerate==1.5.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aiohttp-cors==0.8.0
aiosignal==1.3.2
airportsdata==20250224
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
astor==0.8.1
async-timeout==5.0.1
attrs==25.3.0
blake3==1.0.4
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
codetiming==1.4.0
colorful==0.5.6
compressed-tensors==0.9.2
cupy-cuda12x==13.4.1
datasets==3.4.1
depyf==0.18.0
dill==0.3.8
diskcache==5.6.3
distlib==0.3.9
distro==1.9.0
dnspython==2.7.0
docker-pycreds==0.4.0
einops==0.8.1
email_validator==2.2.0
exceptiongroup==1.2.2
fastapi==0.115.11
fastapi-cli==0.0.7
fastrlock==0.8.3
filelock==3.18.0
flash_attn==2.7.4.post1
frozenlist==1.5.0
fsspec==2024.12.0
gguf==0.10.0
gitdb==4.0.12
GitPython==3.1.44
google-api-core==2.24.2
google-auth==2.38.0
googleapis-common-protos==1.69.2
grpcio==1.71.0
h11==0.14.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.29.3
hydra-core==1.3.2
idna==3.10
importlib_metadata==8.6.1
interegular==0.3.3
Jinja2==3.1.6
jiter==0.9.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lark==1.2.2
liger_kernel==0.5.5
llvmlite==0.43.0
lm-format-enforcer==0.10.11
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
mistral_common==1.5.4
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.19.0
multidict==6.2.0
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
openai==1.68.2
opencensus==0.11.4
opencensus-context==0.1.3
opencv-python-headless==4.11.0.86
orjson==3.10.15
outlines==0.1.11
outlines_core==0.1.26
packaging==24.2
pandas==2.2.3
partial-json-parser==0.2.1.1.post5
peft==0.15.0
pillow==11.1.0
platformdirs==4.3.7
prometheus-fastapi-instrumentator==7.1.0
prometheus_client==0.21.1
propcache==0.3.0
proto-plus==1.26.1
protobuf==5.29.4
psutil==7.0.0
py-cpuinfo==9.0.0
py-spy==0.4.0
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1_modules==0.4.1
pybind11==2.13.6
pycountry==24.6.1
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pylatexenc==2.10
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
python-multipart==0.0.20
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.3.0
ray==2.44.0
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rich-toolkit==0.13.2
rpds-py==0.23.1
rsa==4.9
safetensors==0.5.3
scipy==1.15.2
sentencepiece==0.2.0
sentry-sdk==2.24.0
setproctitle==1.3.5
shellingham==1.5.4
six==1.17.0
smart-open==7.1.0
smmap==5.0.2
sniffio==1.3.1
starlette==0.46.1
sympy==1.13.1
tensordict==0.6.0
tiktoken==0.9.0
tokenizers==0.21.1
torch==2.6.0
torchaudio==2.6.0
torchdata==0.11.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.50.0
triton==3.2.0
typer==0.15.2
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
uvicorn==0.34.0
uvloop==0.21.0
virtualenv==20.29.3
vllm==0.8.1
wandb==0.19.8
watchfiles==1.0.4
websockets==15.0.1
wrapt==1.17.2
xformers==0.0.29.post2
xgrammar==0.1.16
xxhash==3.5.0
yarl==1.18.3
zipp==3.21.0


================================================
FILE: scripts/torl_1.5b.sh
================================================
policy_path=Qwen/Qwen2.5-Math-1.5B
rollout_batch_size=128
n_samples_per_prompts=16
episode=300
temperature=1.0
batch_size=1024
lr=1e-6
kl_loss_coef=0.0
kl_coef=0.0
entropy_coeff=0
max_gen_length=3072
numcall=1
dataset_name=torl_data
run_name=rl.grpo_qwen.math.1.5b_${dataset_name}_numcall${numcall}
data_path=../data/${dataset_name}
samples_save_path=../data/samples/$run_name

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$data_path/train.parquet \
    data.val_files=$data_path/test.parquet \
    data.train_batch_size=$rollout_batch_size \
    data.val_batch_size=2048 \
    data.max_prompt_length=400 \
    data.max_response_length=$max_gen_length \
    data.template_type=tir_base_0309 \
    actor_rollout_ref.model.path=$policy_path \
    actor_rollout_ref.actor.optim.lr=$lr \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=$batch_size \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=$entropy_coeff \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.num_llm_calls_available=$numcall \
    actor_rollout_ref.rollout.temperature=$temperature \
    actor_rollout_ref.rollout.top_p=1.0 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.n=$n_samples_per_prompts \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.rollout.max_num_seqs=1024 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=$kl_coef \
    trainer.logger=['console','wandb'] \
    trainer.project_name=torl \
    trainer.experiment_name=${run_name} \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=4 \
    trainer.default_local_dir=path_to_save/${run_name} \
    trainer.resume_mode=auto \
    trainer.samples_save_path=$samples_save_path \
    trainer.total_epochs=$episode "$@"

================================================
FILE: verl/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ['DataProto', "__version__"]

if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true':
    import importlib
    if importlib.util.find_spec("modelscope") is None:
        raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
    # Patch hub to download models from modelscope to speed up.
    from modelscope.utils.hf_util import patch_hub
    patch_hub()


================================================
FILE: verl/models/README.md
================================================
# Models
Common model zoos such as huggingface/transformers struggle when using PyTorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized model implementation with packed inputs in verl.
## Adding a New Huggingface Model
### Step 1: Copy the model file from HF to verl
- Add a new file under verl/models/hf
- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf

### Step 2: Modify the model file to use packed inputs
- Remove all the code related to inference (kv cache)
- Modify the inputs to include only
    - input_ids (total_nnz,)
    - cu_seqlens (total_nnz + 1,)
    - max_seqlen_in_batch: int
- Note that this requires using flash attention with a causal mask (see the sketch below).
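
A minimal sketch of this packed format in plain PyTorch (illustrative tensors; the flash-attn varlen entry point is only named in a comment):

```python
import torch

# Illustrative sketch: pack three variable-length sequences and build the
# cu_seqlens / max_seqlen metadata described above (int32, as flash-attn expects).
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9]), torch.tensor([10, 11, 12, 13])]

input_ids = torch.cat(seqs)                                   # (total_nnz,) = (9,)
seqlens = torch.tensor([len(s) for s in seqs], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seqlens, 0, dtype=torch.int32)])  # [0, 3, 5, 9]
max_seqlen_in_batch = int(seqlens.max())                      # 4

# flash-attn's varlen kernels (e.g. flash_attn_varlen_func) consume exactly
# these: packed hidden states plus cu_seqlens and max_seqlen, with causal=True.
print(input_ids.tolist(), cu_seqlens.tolist(), max_seqlen_in_batch)
```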

### Step 2.5: Add tests
- Add a test to compare this version and the huggingface version
- Following the infrastructure and add tests to tests/models/hf

### Step 3: Add a function to apply tensor parallelism
- Please follow
    - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
    - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
- General comments
    - Tensor parallelism in native PyTorch is NOT auto-parallelism: you specify, via configs, how model parameters are sharded and how module inputs/outputs are resharded. These configs are registered as hooks that perform the input/output resharding before/after each module's forward (see the sketch below).
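
A minimal sketch of this config-driven style using PyTorch's native TP API (`torch.distributed.tensor.parallel`); the toy MLP and plan are illustrative, and the script assumes it runs under `torchrun`:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module)

# Toy 2-layer MLP standing in for a transformer block (illustrative only).
mlp = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Run under torchrun: init_device_mesh sets up the TP process group.
tp_mesh = init_device_mesh("cuda", (8,))

# The plan maps submodule names to sharding styles; parallelize_module
# registers pre/post-forward hooks that reshard inputs/outputs accordingly.
parallelize_module(mlp, tp_mesh, {
    "0": ColwiseParallel(),   # first Linear: shard weight along output dim
    "2": RowwiseParallel(),   # second Linear: shard weight along input dim
})
```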

### Step 4: Add a function to apply data parallelism
- Please use FSDP2 APIs
- See the demo here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 (a minimal sketch follows)
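
A minimal FSDP2 sketch under the same assumptions (`fully_shard` is the FSDP2 public API in `torch.distributed.fsdp` as of torch 2.6, the version pinned in requirements.txt):

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 public API in torch >= 2.6

# Illustrative stack of blocks; run under torchrun with a process group initialized.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)])

# Shard each layer so its parameters are all-gathered only around its own
# forward/backward, then shard the root to cover any remaining parameters.
for layer in model:
    fully_shard(layer)
fully_shard(model)
```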

### Step 5: Add a function to apply pipeline parallelism
- Comes in Pytorch 2.4
- Currently only in alpha in nightly version
- Check torchtitan for more details



================================================
FILE: verl/models/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: verl/models/llama/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: verl/models/llama/megatron/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_llama_megatron import (
    # original model with megatron
    ParallelLlamaModel,
    ParallelLlamaForCausalLM,
    # rmpad with megatron
    ParallelLlamaForCausalLMRmPad,
    ParallelLlamaForValueRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelLlamaForCausalLMRmPadPP,
    ParallelLlamaForValueRmPadPP)


================================================
FILE: verl/models/llama/megatron/checkpoint_utils/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: verl/models/llama/megatron/checkpoint_utils/llama_loader.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from packaging.version import Version
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map
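# Worked example: with num_hidden_layers=8, pp_size=2, virtual_pp_size=2,
# num_layers_per_model = 2 and layer_offset = vpp_rank * 4 + pp_rank * 2:
#   layers 0-1 -> (pp 0, vpp 0), layers 2-3 -> (pp 1, vpp 0),
#   layers 4-5 -> (pp 0, vpp 1), layers 6-7 -> (pp 1, vpp 1).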


def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
    """Load merged state_dict to sharded Megatron module in training.
    """
    import megatron
    from megatron.core import mpu
    from megatron.training.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def broadcast_params(module):
        for param in module.parameters():
            torch.distributed.broadcast(param.data,
                                        src=mpu.get_data_parallel_src_rank(),
                                        group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]  # wrap a single module in a list instead of iterating it

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _broadcast_tensor(tensor, name) -> torch.Tensor:
        """broadcast tensor from rank0 across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                weight = state_dict[name]
                tensor_shape = weight.shape
            else:
                tensor_shape = None
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
            return

        if tensor is None:
            tensor = torch.empty(
                tensor_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if torch.distributed.get_rank() == 0:
            tensor.data.copy_(weight)
        dist.broadcast(tensor, src=0, group=mp_group)

    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]

                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]
                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
                    torch.cat([gate_weight_tp, up_weight_tp], dim=0))
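            # Row layout of new_gate_up_weight after this loop:
            #   [gate_0; up_0; gate_1; up_1; ...; gate_{tp-1}; up_{tp-1}]
            # so the chunk below hands TP rank i a contiguous [gate_i; up_i]
            # block, matching Megatron's fused gate_up_proj parameter.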

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (
                tensor.shape == chunk_shape
            ), f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads
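
            # Each TP rank's fused qkv_proj stores rows [q_shard; k_shard; v_shard].
            # When num_key_value_heads >= tp_size, K/V heads are partitioned across
            # ranks; otherwise each K/V head is replicated to every TP rank that
            # shares it. The two branches below build this interleaved layout.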

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                new_weight_qkv = torch.empty(total_size * tp_size,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
                                               f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
        )

        print_rank_0("loading lm_head...")
        lm_head_weight = None
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

        if is_value_model:
            # if torch.distributed.get_rank() == 0:
            if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "lm_head.weight")
            elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "reward_head.weight")
                print_rank_0('load lm_head from value_head weight')
            else:
                _broadcast_tensor(None, "lm_head.weight")
                print_rank_0('fail to match lm_head in value_model')
            # else:

            #     _broadcast_tensor(lm_head_weight, "lm_head.weight")

        else:
            _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)

    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")


================================================
FILE: verl/models/llama/megatron/checkpoint_utils/llama_saver.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist

import megatron
from megatron import get_args
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP

from megatron.training.utils import print_rank_0, unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
    """given TP,DP,PP rank to get the global rank."""

    args = get_args()
    tp_size = mpu.get_tensor_model_parallel_world_size()
    dp_size = mpu.get_data_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
           ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
    if args.switch_dp_and_pp_grouping:
        # TP-PP-DP grouping
        return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank
    else:
        # TP-DP-PP grouping
        return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
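# Worked example (TP-DP-PP grouping): with tp_size=2, dp_size=2, pp_size=2,
# (tp_rank=1, dp_rank=0, pp_rank=1) -> (1 * 2 + 0) * 2 + 1 = global rank 5.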


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    args = megatron.get_args()
    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map


def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'):
    """Merge sharded parameters of a Megatron module into a merged checkpoint.

    Args:
        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        config: Hugging Face-style model config (provides num_hidden_layers, etc.).
        is_value_model (bool): whether the model carries a value/reward head.
        dtype (str or None):
            The data type of state_dict. If None, the data type of the original
            parameters is used.
    Returns:
        state_dict (dict):
            The merged state_dict in rank 0, and an empty dictionary in other ranks.
    """
    start_time = time.time()
    args = megatron.get_args()

    def _get_gpt_model(model):
        return model

    dp_rank = mpu.get_data_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if dist.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]  # wrap a single module in a list instead of iterating it

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert len(models[i].model.layers
                  ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
                      len(models[i].model.layers), num_layers_per_model)

    state_dict = dict()

    def _get_cpu_tensor(tensor: torch.Tensor):
        if tensor is None:
            return None
        if tensor.device == torch.device("cpu"):
            return tensor.detach().clone()
        return tensor.detach().cpu()

    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
        """broadcast tensor across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            if tensor is None:
                weight = None
                tensor_shape = None
            else:
                weight = tensor
                tensor_shape = weight.shape
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not exist, skip collect")
            return

        if weight is None:
            weight = torch.empty(
                tensor_shape,
                dtype=args.params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )

        dist.broadcast(weight, src=src_rank, group=mp_group)

        if torch.distributed.get_rank() == 0:
            state_dict[name] = _get_cpu_tensor(weight)

    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
            if mutate_func is not None:
                full_tensor = mutate_func(full_tensor)
            state_dict[name] = full_tensor

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> None:
        """broadcast fused gate/up tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            intermediate_size_tp = config.intermediate_size // tp_size
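            # each TP shard stores the fused weight as [gate_tp; up_tp], so the
            # concatenation above interleaves gate and up blocks; de-interleave
            # them to recover separate HF gate_proj / up_proj weights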
            gate_weight_list = []
            up_weight_list = []
            for i in range(tp_size):
                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
                gate_weight_list.append(gate_weight_tp)
                up_weight_list.append(up_weight_tp)

            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
            state_dict[up_name] = torch.cat(up_weight_list, dim=0)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank) -> None:
        """broadcast fused qkv tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            q_weight_list = []
            k_weight_list = []
            v_weight_list = []
            hidden_size_per_head = config.hidden_size // config.num_attention_heads

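            # Megatron fuses q/k/v per TP shard as [q_tp; k_tp; v_tp]; the k/v
            # layout depends on whether kv heads are sharded across TP ranks
            # (num_key_value_heads >= tp_size) or replicated (< tp_size)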
            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
                    k_weight_list.append(k_part)
                    v_weight_list.append(v_part)
            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
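                    # kv shards are replicated across TP ranks here; keep only
                    # the first copy of each distinct kv head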
                    if i * config.num_key_value_heads % tp_size == 0:
                        k_weight_list.append(k_part)
                        v_weight_list.append(v_part)

            state_dict[q_name] = torch.cat(q_weight_list, dim=0)
            state_dict[k_name] = torch.cat(k_weight_list, dim=0)
            state_dict[v_name] = torch.cat(v_weight_list, dim=0)

    # empty cache before collecting weights
    torch.cuda.empty_cache()
    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("collecting embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        _broadcast_tp_shard_tensor(
            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
            "model.embed_tokens.weight",
            src_pp_rank=0,
        )

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)
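        # layer_map[global_layer_idx] = (pp_rank, virtual_pp_rank, local_layer_idx)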
        for layer in range(config.num_hidden_layers):
            print_rank_0(f"collecting layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[src_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight,
                f"{layer_name}.input_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight,
                f"{layer_name}.self_attn.o_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight,
                f"{layer_name}.post_attention_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
                                               f"{layer_name}.mlp.gate_proj.weight",
                                               f"{layer_name}.mlp.up_proj.weight",
                                               src_pp_rank=src_pp_rank)

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight,
                f"{layer_name}.mlp.down_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

        # Final Layernorm
        # -------------------
        print_rank_0("collecting final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
            src_pp_rank=pp_size - 1,
        )

        print_rank_0("collecting lm_head...")

        if is_value_model:
            _broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                              "reward_head.weight",
                              src_pp_rank=pp_size - 1)

        else:
            _broadcast_tp_shard_tensor(
                getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                "lm_head.weight",
                src_pp_rank=pp_size - 1,
            )

    dist.barrier()

    torch.cuda.empty_cache()
    if torch.distributed.get_rank() == 0:
        if dtype == "fp16":
            dtype = torch.float16
        elif dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype is None or dtype == "fp32":
            dtype = torch.float32
        else:
            print(f'Unknown/unsupported dtype to save: {dtype}')
            exit(1)
        for k, v in state_dict.items():
            if dtype != v.dtype:
                state_dict[k] = v.to(dtype)

    print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
    return state_dict


================================================
FILE: verl/models/llama/megatron/layers/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .parallel_attention import ParallelLlamaAttention
from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm


================================================
FILE: verl/models/llama/megatron/layers/parallel_attention.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class LlamaRotaryEmbedding(nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


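# HF half-split RoPE: with (x1, x2) = (x[..., i], x[..., i + d/2]), the
# function below computes the 2-D rotation (x1*cos - x2*sin, x2*cos + x1*sin)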
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class ParallelLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # assign values after tp
        tp_size = mpu.get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
        assert self.num_key_value_heads % tp_size == 0, \
            f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'

        self.num_heads_per_tp = self.num_heads // tp_size
        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
        self.hidden_size_per_tp = self.hidden_size // tp_size

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.num_heads}).")

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)

        # [self.q_size, self.k_size, self.v_size]
        self.qkv_proj = QKVParallelLinear(input_size=self.hidden_size,
                                          num_heads=self.num_heads,
                                          num_key_value_heads=self.num_key_value_heads,
                                          head_dim=self.head_dim,
                                          bias=config.attention_bias,
                                          gather_output=False,
                                          skip_bias_add=False,
                                          **column_kwargs)

        self.q_size = self.num_heads_per_tp * self.head_dim
        self.k_size = self.num_key_value_heads_per_tp * self.head_dim
        self.v_size = self.num_key_value_heads_per_tp * self.head_dim

        self.o_proj = tensor_parallel.RowParallelLinear(input_size=self.num_heads * self.head_dim,
                                                        output_size=self.hidden_size,
                                                        bias=config.attention_bias,
                                                        input_is_parallel=True,
                                                        skip_bias_add=False,
                                                        **row_kwargs)

        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

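        # grouped-query attention: repeat each kv head so that k/v match the
        # number of local query heads before the dense attention matmuls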
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
        attn_output = self.o_proj(attn_output)[0]
        return attn_output


"""
Remove padding Attention
- Using Flash-attn 2
- Compatible with sequence parallel
"""

from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F

from einops import rearrange

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
    batch_size = position_ids.shape[0]

    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)
    k = pad_input(k, indices, batch_size, sequence_length)
    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
    k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)

    return q_embed, k_embed


from flash_attn.layers.rotary import apply_rotary_emb


# use flash-attn rotary embeddings with rmpad
# cos/sin should be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
    q_embed = apply_rotary_emb(q,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    k_embed = apply_rotary_emb(k,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    return q_embed, k_embed


class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):

    def forward(self,
                hidden_states: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: torch.Tensor = None,
                max_seqlen_in_batch: int = None):
        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel

        if self.megatron_config.sequence_parallel:
            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
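            # under sequence parallel the local input holds only 1/tp of the
            # tokens; the column-parallel qkv_proj all-gathers them back, so
            # recover the full (padded) token count here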

        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
                                                           dim=-1)  # (total_nnz, 1, hidden_size)

        if self.megatron_config.sequence_parallel:
            sequence_parallel_pad = total_nnz - cu_seqlens[-1]
            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding
            query_states = query_states[:total_nnz]
            key_states = key_states[:total_nnz]
            value_states = value_states[:total_nnz]

        # flash_attn_varlen_func expects unpadded inputs of shape
        # (total_nnz, num_heads, head_dim), so a reshape is all that is needed
        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)

        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
        cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2]  # flash attn only needs half
        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
                                                                    key_states,
                                                                    cos,
                                                                    sin,
                                                                    cu_seqlens=cu_seqlens,
                                                                    max_seqlen=max_seqlen_in_batch)
        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,

        # TODO: llama does not have dropout in the config??
        # It is recommended to use dropout with FA according to the docs
        # when training.
        dropout_rate = 0.0  # if not self.training else self.attn_dropout

        # In PEFT, the layer norms are usually cast to float32 for training
        # stability, so the input hidden states may get silently cast to
        # float32 as well. Cast them back to float16 so everything works as
        # expected. This can slow down training & inference, so it is
        # recommended not to cast the LayerNorms to fp32 (LlamaRMSNorm
        # handles it correctly).
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            query_states = query_states.to(torch.float16)
            key_states = key_states.to(torch.float16)
            value_states = value_states.to(torch.float16)

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen_in_batch,
            max_seqlen_k=max_seqlen_in_batch,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
        )

        attn_output_unpad = attn_output_unpad.to(input_dtype)
        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()

        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
        # Here we need to repad
        if self.megatron_config.sequence_parallel:
            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))

        attn_output_unpad = self.o_proj(attn_output_unpad)[0]
        return attn_output_unpad


================================================
FILE: verl/models/llama/megatron/layers/parallel_decoder.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig
from megatron.core import ModelParallelConfig

from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
from .parallel_mlp import ParallelLlamaMLP
from .parallel_rmsnorm import ParallelLlamaRMSNorm


class ParallelLlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)

        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Note: sequence parallel is hidden inside ColumnParallelLinear
        # reduce scatter is hidden inside RowParallelLinear

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # TODO: add sequence parallel operator all_gather here

        hidden_states = self.mlp(hidden_states)

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs


class ParallelLlamaDecoderLayerRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)

        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        sequence_length: int = None,
        indices: torch.Tensor = None,
        cu_seqlens: torch.Tensor = None,
        max_seqlen_in_batch: int = None
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       position_ids=position_ids,
                                       sequence_length=sequence_length,
                                       indices=indices,
                                       cu_seqlens=cu_seqlens,
                                       max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = residual + hidden_states

        # Fully Connected
        # shape changes same as attn
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs


================================================
FILE: verl/models/llama/megatron/layers/parallel_linear.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py

from typing import Optional, Tuple

from megatron.core import tensor_parallel


class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Keep input parameters; the parent class shards the output across TP ranks
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        input_size = self.input_size
        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
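        # the fused column-parallel weight produces [q; k; v] per TP shard;
        # callers split the output with sizes (q_size, k_size, v_size)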

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 gate_output_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # gate and up projections are fused into a single column-parallel weight
        self.input_size = input_size
        self.output_size = gate_output_size + up_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=self.input_size,
                         output_size=self.output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


================================================
FILE: verl/models/llama/megatron/layers/parallel_mlp.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class ParallelLlamaMLP(nn.Module):

    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)

        tp_size = mpu.get_tensor_model_parallel_world_size()

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            gate_output_size=self.intermediate_size,
            up_output_size=self.intermediate_size,
            bias=False,
            gather_output=False,
            skip_bias_add=False,
            **column_kwargs,
        )
        self.gate_size = self.intermediate_size // tp_size

        self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
                                                           output_size=self.hidden_size,
                                                           bias=False,
                                                           input_is_parallel=True,
                                                           skip_bias_add=False,
                                                           **row_kwargs)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
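        # SwiGLU: split the fused projection into gate and up halves, then
        # compute down_proj(act(gate) * up)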
        gate_up = self.gate_up_proj(x)[0]
        gate, up = gate_up.split(self.gate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)[0]


================================================
FILE: verl/models/llama/megatron/layers/parallel_rmsnorm.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig

from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils


class ParallelLlamaRMSNorm(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        if isinstance(config.hidden_size, numbers.Integral):
            normalized_shape = (config.hidden_size,)
        else:
            normalized_shape = tuple(config.hidden_size)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)

================================================
FILE: verl/models/llama/megatron/modeling_llama_megatron.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast

from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
"""
TODO: 
1. Add weight initialization. Here we need to be careful on TP weight init.
2. Add sequence parallel
3. Load checkpoint from meta LLama pretrained checkpoint
"""


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class ParallelLlamaModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
                                              tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
                                       combined_attention_mask)

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: input ids. shape (batch_size, seq_length)
            attention_mask: attention_mask. shape (batch_size, seq_length)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:

        """
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embed_tokens(input_ids)
        # embed positions

        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states


class ParallelLlamaForCausalLM(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)

        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
                                                            output_size=config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        ```"""

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)[0]

        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        logits = logits.float()
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


class ParallelLlamaModelRmPad(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        self.megatron_config = megatron_config
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                   embedding_dim=config.hidden_size,
                                                                   **embedding_kwargs)

        self.layers = nn.ModuleList(
            [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
        self.norm = ParallelLlamaRMSNorm(config, megatron_config)

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: Optional[int] = None,
                indices: Optional[torch.Tensor] = None,
                cu_seqlens: Optional[torch.Tensor] = None,
                max_seqlen_in_batch: Optional[int] = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: ids of the packed (padding-removed) tokens. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:
            the final hidden states, sharded along the sequence dim when sequence parallelism is enabled

        """
        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
        inputs_embeds = inputs_embeds.transpose(0, 1)
        if self.megatron_config.sequence_parallel:
            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

        hidden_states = inputs_embeds
        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states
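
# Sketch (single-process emulation, not the original code): what
# scatter_to_sequence_parallel_region does to the packed embeddings above.
# The (total_nnz, 1, hidden_size) tensor is split along the sequence dim so
# each TP rank keeps total_nnz // tp_size tokens; pad_to_sequence_parallel
# has already made total_nnz divisible by tp_size.
def _sketch_sequence_parallel_scatter(inputs_embeds: torch.Tensor,
                                      tp_size: int, tp_rank: int) -> torch.Tensor:
    chunks = torch.chunk(inputs_embeds, tp_size, dim=0)
    return chunks[tp_rank]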


class ParallelLlamaForCausalLMRmPad(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
        self.vocab_size = config.vocab_size
        self._init_head()

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from sequence parallel region is performed inside lm_head
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        ```"""
        batch_size, sequence_length = input_ids.shape

        # remove padding here
        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                              attention_mask)  # (total_nnz, 1)

        # pad input_ids to a multiple of tp size so every tp rank holds an equal sequence shard
        # TODO: for better performance, the sp padding could be removed at each layer; the performance gap is untested
        if self.megatron_config.sequence_parallel:
            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)

        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = outputs

        logits = self._forward_head(hidden_states)

        # remove padding from sequence parallel
        if self.megatron_config.sequence_parallel:
            total_nnz = cu_seqlens[-1]
            logits = logits[:total_nnz]  # (total_nnz, 1, vocab_size)

        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension
        # add removed padding back
        logits = pad_input(logits, indices, batch_size,
                           seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
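
# Sketch (plain torch, hypothetical helper): the unpad_input / pad_input
# round trip used in the forward above, re-implemented on a tiny example so
# the bookkeeping (indices, cu_seqlens) is visible without flash_attn.
def _sketch_rmpad_roundtrip():
    # two sequences of lengths 3 and 2, right-padded to length 4
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    input_ids = torch.tensor([[11, 12, 13, 0],
                              [21, 22, 0, 0]])
    # unpad: keep only the positions where the mask is 1
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    packed = input_ids.flatten()[indices]  # (total_nnz,) == (5,)
    seqlens = attention_mask.sum(dim=-1)
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0), (1, 0))  # [0, 3, 5]
    # pad: scatter the packed values back into a (batch, seqlen) layout
    restored = packed.new_zeros(attention_mask.numel())
    restored[indices] = packed
    restored = restored.view_as(input_ids)
    assert torch.equal(restored, input_ids)
    return packed, indices, cu_seqlens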


class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # the value head's output is sequence-sharded, just like a sequence-parallel activation,
        # so mark its weight as sequence parallel for gradient handling
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids, attention_mask, position_ids)
        output.logits = torch.squeeze(output.logits, dim=-1)
        return output
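
# Sketch (single-process, hypothetical names): the plain nn.Linear value head
# above runs independently on each rank's sequence-parallel slice, so its
# outputs stay sequence-sharded; gather_from_sequence_parallel_region then
# concatenates the per-rank slices back along the sequence dim.
def _sketch_value_head_gather(hidden_slices, value_head):
    # hidden_slices: one (total_nnz_padded // tp, 1, hidden_size) tensor per rank
    return torch.cat([value_head(h) for h in hidden_slices], dim=0)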


"""
Support pipeline parallelism
"""


class ParallelLlamaModelRmPadPP(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
    This model definition supports pipeline parallelism. To support pp and vpp,
    - this model only contains the layers of its own pp stage and vpp chunk;
    - when Megatron calls get_model, each rank instantiates all the vpp chunks of its pp stage.
    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        self.megatron_config = megatron_config
        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
        if megatron_config is not None:
            assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
        if pre_process:
            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                                       embedding_dim=config.hidden_size,
                                                                       **embedding_kwargs)
        else:
            self.embed_tokens = None

        # pp_rank = megatron_config.pipeline_model_parallel_rank
        pp_size = megatron_config.pipeline_model_parallel_size
        self.num_layer_per_pp = config.num_hidden_layers // pp_size
        vpp_size = megatron_config.virtual_pipeline_model_parallel_size

        if vpp_size is not None:
            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
            self.num_layer_this_model = self.num_layer_vpp_chunk
            # vpp_rank = megatron_config.virtual_pipeline_model_parallel_rank
            # self.offset = vpp_rank * (
            #         config.num_hidden_layers // megatron_config.virtual_pipeline_model_parallel_size) + \
            #             (megatron_config.pipeline_model_parallel_rank * self.num_layer_vpp_chunk)
        else:
            self.num_layer_this_model = self.num_layer_per_pp
            # self.offset = pp_rank * self.num_layer_per_pp

        layers = []
        for i in range(self.num_layer_this_model):
            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config)
            # setattr(layer, 'hidden_layer_index', self.offset + i)
            layers.append(layer)

        self.layers = nn.ModuleList(layers)

        if post_process:
            self.norm = ParallelLlamaRMSNorm(config, megatron_config)
        else:
            self.norm = None

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: Optional[int] = None,
                indices: Optional[torch.Tensor] = None,
                cu_seqlens: Optional[torch.Tensor] = None,
                max_seqlen_in_batch: Optional[int] = None) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Args:
            input_ids: ids of the packed (padding-removed) tokens. shape (1, total_nnz)
            position_ids: position ids. shape (batch_size, seq_length)

        Returns:
            the hidden states of this pp stage, sharded along the sequence dim when sequence parallelism is enabled

        """
        if self.pre_process:
            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)

            # the vocab parallel embedding does not do the sequence parallel scatter in open-source megatron,
            # so we need to handle it here:
            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
            inputs_embeds = inputs_embeds.transpose(0, 1)
            if self.megatron_config.sequence_parallel:
                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)

            hidden_states = inputs_embeds
        else:
            # self.hidden_states should be passed by Megatron
            hidden_states = self.input_tensor

        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(hidden_states,
                                          position_ids=position_ids,
                                          sequence_length=sequence_length,
                                          indices=indices,
                                          cu_seqlens=cu_seqlens,
                                          max_seqlen_in_batch=max_seqlen_in_batch)

            hidden_states = layer_outputs

        if self.post_process:
            hidden_states = self.norm(hidden_states)

        return hidden_states
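
# Sketch: the layer-partition arithmetic used in __init__ above, as plain
# Python. With 32 hidden layers, pp_size=4 and vpp_size=2, each pp stage owns
# 8 layers and each of its 2 virtual chunks owns 4.
def _sketch_pp_layer_partition(num_hidden_layers: int, pp_size: int, vpp_size=None) -> int:
    num_layer_per_pp = num_hidden_layers // pp_size
    if vpp_size is not None:
        return num_layer_per_pp // vpp_size  # layers per vpp chunk
    return num_layer_per_pp  # layers per pp stage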


class ParallelLlamaForCausalLMRmPadPP(nn.Module):

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process,
                 share_embeddings_and_output_weights):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.model = ParallelLlamaModelRmPadPP(config,
                                               megatron_config=megatron_config,
                                               pre_process=pre_process,
                                               post_process=post_process)
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.vocab_size = config.vocab_size
        self.pre_process = pre_process
        self.post_process = post_process
        if post_process:
            self._init_head()

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        assert len(input_tensor) == 1
        self.model.set_input_tensor(input_tensor[0])

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.config.hidden_size,
                                                            output_size=self.config.vocab_size,
                                                            bias=False,
                                                            gather_output=False,
                                                            skip_bias_add=False,
                                                            **column_kwargs)

    def _forward_head(self, hidden_states):
        # all_gather from the sequence parallel region is performed inside lm_head
        logits = self.lm_head(hidden_states)[0]
        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)
        return logits

    def forward(
        self,
        # original input
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        ```"""

        # Note that input_ids, attention_mask and position_ids should be passed to every pp stage.
        # On the first pp stage input_ids are consumed; later stages use the hidden states passed in via set_input_tensor.
        batch_size, sequence_length = input_ids.shape
        # remove padding here
        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
                                                                                    attention_mask)  # (total_nnz, 1)

        # pad input_ids to a multiple of tp size so every tp rank holds an equal sequence shard
        # TODO: for better performance, the sp padding could be removed at each layer; the performance gap is untested
        if self.megatron_config.sequence_parallel:
            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)

        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)

        outputs = self.model(input_ids=input_ids_rmpad,
                             position_ids=position_ids,
                             sequence_length=sequence_length,
                             indices=indices,
                             cu_seqlens=cu_seqlens,
                             max_seqlen_in_batch=max_seqlen_in_batch)

        if self.post_process:
            hidden_states = outputs
            logits = self._forward_head(hidden_states)
            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension

            # remove padding from sequence parallel
            if self.megatron_config.sequence_parallel:
                total_nnz = cu_seqlens[-1]
                logits = logits[:total_nnz]  # drop the sequence-parallel padding
            # add removed padding back. If input is already rmpad, we let the caller pad_input
            logits = pad_input(logits, indices, batch_size,
                               seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)

            return CausalLMOutputWithPast(
                loss=None,
                logits=logits,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
            )
        else:
            return outputs
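
# Sketch (hypothetical single-process driver, not Megatron's schedule): how
# activations thread through the pp stages of this model. Stage 0 consumes
# input_ids; later stages receive the previous stage's hidden states via
# set_input_tensor; only the post_process stage returns logits.
def _sketch_pp_forward(stages, input_ids, attention_mask, position_ids):
    out = stages[0](input_ids=input_ids, attention_mask=attention_mask,
                    position_ids=position_ids)
    for stage in stages[1:]:
        stage.set_input_tensor([out])
        out = stage(input_ids=input_ids, attention_mask=attention_mask,
                    position_ids=position_ids)
    return out  # CausalLMOutputWithPast from the last stage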


class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):

    def _init_head(self):
        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        if self.megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
        self.lm_head = nn.Linear(in_features=self.config.hidden_size, out_features=1, bias=False)
        # the value head's output is sequence-sharded, just like a sequence-parallel activation,
        # so mark its weight as sequence parallel for gradient handling
        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)

    def _forward_head(self, hidden_states):
        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)
        logits = logits.float()
        if self.megatron_config.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
        return logits

    def forward(
        self,
        *,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
        if self.post_process:
            output.logits = torch.squeeze(output.logits, dim=-1)
        return output
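
# Sketch (hypothetical helper, not part of the original API): which of the
# classes defined in this file corresponds to which combination of
# remove-padding, pipeline-parallel and value-head usage.
def _sketch_select_llama_class(rmpad: bool, pp: bool, value: bool):
    table = {
        (False, False, False): ParallelLlamaForCausalLM,
        (True, False, False): ParallelLlamaForCausalLMRmPad,
        (True, False, True): ParallelLlamaForValueRmPad,
        (True, True, False): ParallelLlamaForCausalLMRmPadPP,
        (True, True, True): ParallelLlamaForValueRmPadPP,
    }
    return table[(rmpad, pp, value)]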


================================================
FILE: verl/models/qwen2/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: verl/models/qwen2/megatron/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_qwen2_megatron import (
    # original model with megatron
    ParallelQwen2Model,
    ParallelQwen2ForCausalLM,
    # rmpad with megatron
    ParallelQwen2ForCausalLMRmPad,
    ParallelQwen2ForValueRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelQwen2ForCausalLMRmPadPP,
    ParallelQwen2ForValueRmPadPP)


================================================
FILE: verl/models/qwen2/megatron/checkpoint_utils/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map
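
# Sketch: the same layer-map arithmetic re-run standalone (no megatron
# import) on a concrete case. With num_hidden_layers=8, pp_size=2 and
# virtual_pp_size=2 each of the four model chunks owns two layers, and global
# layer 5 lands at (pp_rank=0, virtual_pp_rank=1, local layer_idx=1).
def _sketch_layer_map(num_hidden_layers=8, pp_size=2, virtual_pp_size=2):
    num_layers_per_model = num_hidden_layers // pp_size // virtual_pp_size
    layer_map = {}
    for pp_rank in range(pp_size):
        for vpp_rank in range(virtual_pp_size):
            offset = (vpp_rank * (num_hidden_layers // virtual_pp_size) +
                      pp_rank * num_layers_per_model)
            for local_idx in range(num_layers_per_model):
                layer_map[offset + local_idx] = (pp_rank, vpp_rank, local_idx)
    return layer_map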


def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
    """Load merged state_dict to sharded Megatron module in training.
    """
    import megatron
    from megatron.core import mpu
    from megatron.training.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def broadcast_params(module):
        for param in module.parameters():
            torch.distributed.broadcast(param.data,
                                        src=mpu.get_data_parallel_src_rank(),
                                        group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _broadcast_tensor(tensor, name) -> None:
        """broadcast tensor from rank0 across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                weight = state_dict[name]
                tensor_shape = weight.shape
            else:
                tensor_shape = None
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
            return

        if tensor is None:
            tensor = torch.empty(
                tensor_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if torch.distributed.get_rank() == 0:
            tensor.data.copy_(weight)
        dist.broadcast(tensor, src=0, group=mp_group)

    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]

                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            if name in state_dict:
                full_weight = state_dict[name]
                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2,
                                             config.hidden_size,
                                             dtype=params_dtype,
                                             device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
                    torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (
                tensor.shape == chunk_shape
            ), f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                if not bias:
                    new_weight_qkv = torch.empty(total_size * tp_size,
                                                 config.hidden_size,
                                                 dtype=params_dtype,
                                                 device=torch.cuda.current_device())
                else:
                    new_weight_qkv = torch.empty(total_size * tp_size,
                                                 dtype=params_dtype,
                                                 device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                if not bias:
                    new_weight_qkv = torch.empty(total_size * tp_size,
                                                 config.hidden_size,
                                                 dtype=params_dtype,
                                                 device=torch.cuda.current_device())
                else:
                    new_weight_qkv = torch.empty(total_size * tp_size,
                                                 dtype=params_dtype,
                                                 device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
                                                                                        dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert (tensor.shape == chunk_shape
                   ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )

            _broadcast_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
                                           f"{layer_name}.self_attn.q_proj.bias",
                                           f"{layer_name}.self_attn.k_proj.bias",
                                           f"{layer_name}.self_attn.v_proj.bias",
                                           bias=True)

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
                                               f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
        )

        print_rank_0("loading lm_head...")
        lm_head_weight = None
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

        if is_value_model:
            if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "lm_head.weight")
            elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "reward_head.weight")
                print_rank_0('loaded lm_head from the value_head weight')
            else:
                _broadcast_tensor(None, "lm_head.weight")
                print_rank_0('failed to match lm_head in the value model')

        else:
            _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)

    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")


================================================
FILE: verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import megatron
from megatron.core import mpu
from megatron.training.utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist
from megatron import get_args


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
    """given TP,DP,PP rank to get the global rank."""

    args = get_args()
    tp_size = mpu.get_tensor_model_parallel_world_size()
    dp_size = mpu.get_data_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
           ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
    if args.switch_dp_and_pp_grouping:
        # TP-PP-DP grouping
        return (dp_rank * pp_size + pp_rank) * tp_size + tp_rank
    else:
        # TP-DP-PP grouping
        return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
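
# Sketch: concrete evaluation of the TP-DP-PP grouping above. With tp_size=2,
# dp_size=2, pp_size=2 (world size 8), rank (tp=1, dp=0, pp=1) maps to
# (1 * 2 + 0) * 2 + 1 = 5.
def _sketch_global_rank(tp_rank: int, dp_rank: int, pp_rank: int,
                        tp_size: int = 2, dp_size: int = 2) -> int:
    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank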


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    import megatron
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    args = megatron.get_args()
    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
                            pp_rank_idx * num_layers_per_model)
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map


def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtype='bf16'):
    """Merge sharded parameters of a Megatron module into a merged checkpoint.

    Args:
        wrapped_modelss (list of megatron.core.distributed.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        dtype (str or None):
            The data type of state_dict. if None, the data type of the original parameters
            is used.
        gpt_model_key: key to access model
    Returns:
        state_dict (dict):
            The merged state_dict in rank 0, and an empty dictionary in other ranks.
    """
    start_time = time.time()
    args = megatron.get_args()

    def _get_gpt_model(model):
        return model

    dp_rank = mpu.get_data_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if dist.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = [wrapped_models]

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert len(models[i].model.layers
                  ) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
                      len(models[i].model.layers), num_layers_per_model)

    state_dict = dict()

    def _get_cpu_tensor(tensor: torch.Tensor):
        if tensor is None:
            return None
        if tensor.device == torch.device("cpu"):
            return tensor.detach().clone()
        return tensor.detach().cpu()

    def _broadcast_tensor(tensor, name, src_pp_rank) -> None:
        """broadcast tensor across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            if tensor is None:
                weight = None
                tensor_shape = None
            else:
                weight = tensor
                tensor_shape = weight.shape
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not exist, skip collect")
            return

        if weight is None:
            weight = torch.empty(
                tensor_shape,
                dtype=args.params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )

        dist.broadcast(weight, src=src_rank, group=mp_group)

        if torch.distributed.get_rank() == 0:
            state_dict[name] = _get_cpu_tensor(weight)

    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
            if mutate_func is not None:
                full_tensor = mutate_func(full_tensor)
            state_dict[name] = full_tensor

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> None:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            intermediate_size_tp = config.intermediate_size // tp_size
            gate_weight_list = []
            up_weight_list = []
            for i in range(tp_size):
                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
                gate_weight_list.append(gate_weight_tp)
                up_weight_list.append(up_weight_tp)

            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
            state_dict[up_name] = torch.cat(up_weight_list, dim=0)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            chunk_shape = tensor.shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=args.params_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            q_weight_list = []
            k_weight_list = []
            v_weight_list = []
            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
                    k_weight_list.append(k_part)
                    v_weight_list.append(v_part)
            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
                    q_weight_list.append(q_part)
                    if i * config.num_key_value_heads % tp_size == 0:
                        k_weight_list.append(k_part)
                        v_weight_list.append(v_part)

            state_dict[q_name] = torch.cat(q_weight_list, dim=0)
            state_dict[k_name] = torch.cat(k_weight_list, dim=0)
            state_dict[v_name] = torch.cat(v_weight_list, dim=0)

    # empty cache before collecting weights
    torch.cuda.empty_cache()
    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("collecting embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        _broadcast_tp_shard_tensor(
            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
            "model.embed_tokens.weight",
            src_pp_rank=0,
        )

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)
        for layer in range(config.num_hidden_layers):
            print_rank_0(f"collecting layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[src_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight,
                f"{layer_name}.input_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight,
                f"{layer_name}.self_attn.o_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight,
                f"{layer_name}.post_attention_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
                                               f"{layer_name}.mlp.gate_proj.weight",
                                               f"{layer_name}.mlp.up_proj.weight",
                                               src_pp_rank=src_pp_rank)

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight,
                f"{layer_name}.mlp.down_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

        # Final Layernorm
        # -------------------
        print_rank_0("collecting final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
            src_pp_rank=pp_size - 1,
        )

        print_rank_0("collecting lm_head...")

        if is_value_model:
            _broadcast_tensor(getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                              "reward_head.weight",
                              src_pp_rank=pp_size - 1)

        else:
            _broadcast_tp_shard_tensor(
                getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                "lm_head.weight",
                src_pp_rank=pp_size - 1,
            )

    dist.barrier()

    torch.cuda.empty_cache()
    if torch.distributed.get_rank() == 0:
        if dtype == "fp16":
            dtype = torch.float16
        elif dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype is None or dtype == "fp32":
            dtype = torch.float32
        else:
            print(f'Unknown/unsupported dtype to save: {dtype}')
            exit(1)
        for k, v in state_dict.items():
            if dtype != v.dtype:
                state_dict[k] = v.to(dtype)

    print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
    return state_dict
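

# Illustrative sketch (not part of the original module): the rank-0 merging
# step above interleaves gate/up slices per TP shard, which is easy to misread.
# This hypothetical standalone helper mirrors the splitting logic of
# _broadcast_tp_shard_tensor_gate_up for a fused gate_up weight that has been
# concatenated from tp_size shards along dim 0.
def _demo_split_gate_up(full_tensor, intermediate_size, tp_size):
    """Split a TP-concatenated fused gate_up weight back into gate/up."""
    intermediate_size_tp = intermediate_size // tp_size
    gate_up_size_tp = 2 * intermediate_size_tp
    gate_parts, up_parts = [], []
    for i in range(tp_size):
        # each TP shard stores its gate slice followed by its up slice
        shard = full_tensor[i * gate_up_size_tp:(i + 1) * gate_up_size_tp]
        gate_parts.append(shard[:intermediate_size_tp])
        up_parts.append(shard[intermediate_size_tp:])
    return torch.cat(gate_parts, dim=0), torch.cat(up_parts, dim=0)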


================================================
FILE: verl/models/qwen2/megatron/layers/__init__.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .parallel_attention import ParallelQwen2Attention
from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
from .parallel_mlp import ParallelQwen2MLP
from .parallel_rmsnorm import ParallelQwen2RMSNorm


================================================
FILE: verl/models/qwen2/megatron/layers/parallel_attention.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import Qwen2Config
from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class Qwen2RotaryEmbedding(nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                device=self.inv_freq.device,
                                dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
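

# Illustrative sketch (not part of the original module): the cos/sin cache is
# rebuilt lazily in forward() whenever a sequence longer than the current
# cache arrives.
def _demo_rope_cache_growth():
    rope = Qwen2RotaryEmbedding(dim=8, max_position_embeddings=16)
    x = torch.randn(1, 1, 32, 8)
    cos, sin = rope(x, seq_len=32)  # 32 > 16 triggers _set_cos_sin_cache(32)
    assert rope.max_seq_len_cached == 32
    assert cos.shape == (32, 8) and sin.shape == (32, 8)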


class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):
    """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):
    """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
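

# Illustrative sketch (not part of the original module): RoPE applies a pure
# per-pair rotation, so it leaves the norms of q and k unchanged. The cos/sin
# tensors here are built with the same cat((freqs, freqs)) layout as the cache.
def _demo_rope_preserves_norm():
    bs, heads, seq, dim = 1, 2, 5, 8
    q = torch.randn(bs, heads, seq, dim)
    k = torch.randn(bs, heads, seq, dim)
    position_ids = torch.arange(seq).unsqueeze(0)  # (bs, seq)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.einsum("i,j->ij", torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq, dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin(), position_ids)
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
    assert torch.allclose(k_rot.norm(dim=-1), k.norm(dim=-1), atol=1e-5)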


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
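

# Illustrative check (not part of the original module): repeat_kv matches
# torch.repeat_interleave along the key/value head dimension, exactly as the
# docstring above states.
def _demo_repeat_kv_equivalence():
    x = torch.randn(2, 4, 3, 8)  # (batch, num_kv_heads, seqlen, head_dim)
    assert torch.equal(repeat_kv(x, 2), torch.repeat_interleave(x, repeats=2, dim=1))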


class ParallelQwen2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # assign values after tp
        tp_size = mpu.get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
        assert self.num_key_value_heads % tp_size == 0, \
            f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'

        self.num_heads_per_tp = self.num_heads // tp_size
        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
        self.hidden_size_per_tp = self.hidden_size // tp_size

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.num_heads}).")

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)

        # [self.q_size, self.k_size, self.v_size]
        self.qkv_proj = QKVParallelLinear(
            input_size=self.hidden_size,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_key_value_heads,
            head_dim=self.head_dim,
            # bias=config.attention_bias,
            bias=True,
            gather_output=False,
            skip_bias_add=False,
            **column_kwargs)

        self.q_size = self.num_heads_per_tp * self.head_dim
        self.k_size = self.num_key_value_heads_per_tp * self.head_dim
        self.v_size = self.num_key_value_heads_per_tp * self.head_dim

        self.o_proj = tensor_parallel.RowParallelLinear(
            input_size=self.num_heads * self.head_dim,
            output_size=self.hidden_size,
            # bias=config.attention_bias,
            bias=False,
            input_is_parallel=True,
            skip_bias_add=False,
            **row_kwargs)

        self._init_rope()

    def _init_rope(self):
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
        attn_output = self.o_proj(attn_output)[0]
        return attn_output


"""
Remove padding Attention
- Using Flash-attn 2
- Compatible with sequence parallel
"""

from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F

from einops import rearrange

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
    batch_size = position_ids.shape[0]

    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)
    k = pad_input(k, indices, batch_size, sequence_length)
    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
    k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)

    return q_embed, k_embed


from flash_attn.layers.rotary import apply_rotary_emb


# use flash-attn rotary embeddings with rmpad
# cos/sin should be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
    q_embed = apply_rotary_emb(q,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    k_embed = apply_rotary_emb(k,
                               cos,
                               sin,
                               interleaved=False,
                               inplace=False,
                               cu_seqlens=cu_seqlens,
                               max_seqlen=max_seqlen)
    return q_embed, k_embed


class ParallelQwen2AttentionRmPad(ParallelQwen2Attention):

    def forward(self,
                hidden_states: torch.Tensor,
                position_ids: Optional[torch.LongTensor] = None,
                sequence_length: int = None,
                indices: torch.Tensor = None,
                cu_seqlens: torch.Tensor = None,
                max_seqlen_in_batch: int = None):
        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel

        if self.megatron_config.sequence_parallel:
            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()

        qkv = self.qkv_proj(hidden_states)[0]
        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
                                                           dim=-1)  # (total_nnz, 1, hidden_size)

        if self.megatron_config.sequence_parallel:
            sequence_parallel_pad = total_nnz - cu_seqlens[-1]
            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding
            query_states = query_states[:total_nnz]
            key_states = key_states[:total_nnz]
            value_states = value_states[:total_nnz]

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)

        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
        cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2]  # flash attn only needs half
        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
                                                                    key_states,
                                                                    cos,
                                                                    sin,
                                                                    cu_seqlens=cu_seqlens,
                                                                    max_seqlen=max_seqlen_in_batch)
        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,

        # It is recommended to use dropout with FA according to the docs
        # when training.
        dropout_rate = 0.0  # if not self.training else self.attn_dropout

        # In PEFT, the layer norms are usually cast to float32 for training
        # stability, so the input hidden states may get silently cast to float32.
        # Hence, we cast them back to float16 to make sure everything works as
        # expected. This might slow down training & inference, so it is
        # recommended not to cast the LayerNorms to fp32. (Qwen2RMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            query_states = query_states.to(torch.float16)
            key_states = key_states.to(torch.float16)
            value_states = value_states.to(torch.float16)

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen_in_batch,
            max_seqlen_k=max_seqlen_in_batch,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
        )

        attn_output_unpad = attn_output_unpad.to(input_dtype)
        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()

        # sequence parallel reduce_scatter is performed inside RowParallelLinear if enabled
        # Here we need to repad
        if self.megatron_config.sequence_parallel:
            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))

        attn_output_unpad = self.o_proj(attn_output_unpad)[0]
        return attn_output_unpad
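

# Illustrative sketch (not part of the original module): how the cu_seqlens
# and indices consumed by the remove-padding path relate to an attention mask,
# mirroring what flash_attn's unpad_input computes.
def _demo_cu_seqlens_from_mask():
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 2])
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    # cu_seqlens == tensor([0, 3, 5]); these offsets delimit each sequence
    # inside the flattened (total_nnz, ...) tensors passed to flash attention
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # indices == tensor([0, 1, 2, 4, 5]): positions of the real tokens
    return cu_seqlens, indices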


================================================
FILE: verl/models/qwen2/megatron/layers/parallel_decoder.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config
from megatron.core import ModelParallelConfig

from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad
from .parallel_mlp import ParallelQwen2MLP
from .parallel_rmsnorm import ParallelQwen2RMSNorm


class ParallelQwen2DecoderLayer(nn.Module):

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config)

        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Note: sequence parallel is hidden inside ColumnParallelLinear
        # reduce scatter is hidden inside RowParallelLinear

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # TODO: add sequence parallel operator all_gather here

        hidden_states = self.mlp(hidden_states)

        # TODO: add sequence parallel operator reduce_scatter here

        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs


class ParallelQwen2DecoderLayerRmPad(nn.Module):

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        super().__init__()
        self.config = config
        self.megatron_config = megatron_config
        self.hidden_size = config.hidden_size
        self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config)

        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)
        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        sequence_length: int = None,
        indices: torch.Tensor = None,
        cu_seqlens: torch.Tensor = None,
        max_seqlen_in_batch: int = None
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       position_ids=position_ids,
                                       sequence_length=sequence_length,
                                       indices=indices,
                                       cu_seqlens=cu_seqlens,
                                       max_seqlen_in_batch=max_seqlen_in_batch)

        hidden_states = residual + hidden_states

        # Fully Connected
        # shape changes same as attn
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = hidden_states

        return outputs
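

# Note (not part of the original module): with sequence parallel enabled, each
# rank holds total_nnz // sp tokens of the residual stream. The attention and
# MLP blocks all-gather to total_nnz inside ColumnParallelLinear and
# reduce-scatter back inside RowParallelLinear, so every residual addition in
# the forward above happens on the local (total_nnz // sp, 1, hidden_size) shard.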


================================================
FILE: verl/models/qwen2/megatron/layers/parallel_linear.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py

from typing import Optional, Tuple

from megatron.core import tensor_parallel


class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Keep the input parameters; the fused output is laid out as [Q | K | V]
        # and sharded by head across tensor-parallel ranks
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        input_size = self.input_size
        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
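

# Worked example (illustrative numbers, not from the original module): with
# num_heads=12, num_key_value_heads=2 and head_dim=128 (Qwen2-1.5B-like),
# output_size = (12 + 2 * 2) * 128 = 2048, laid out as [Q | K | V] before
# ColumnParallelLinear splits it column-wise across TP ranks.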


class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):

    def __init__(self,
                 input_size,
                 gate_output_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Keep the input parameters; the fused output is laid out as [gate | up]
        # and sharded column-wise across tensor-parallel ranks
        self.input_size = input_size
        self.output_size = gate_output_size + up_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=self.input_size,
                         output_size=self.output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


================================================
FILE: verl/models/qwen2/megatron/layers/parallel_mlp.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear

from verl.utils.megatron import tensor_parallel as tp_utils


class ParallelQwen2MLP(nn.Module):

    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]

        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)

        tp_size = mpu.get_tensor_model_parallel_world_size()

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            gate_output_size=self.intermediate_size,
            up_output_size=self.intermediate_size,
            bias=False,
            gather_output=False,
            skip_bias_add=False,
            **column_kwargs,
        )
        self.gate_size = self.intermediate_size // tp_size

        self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
                                                           output_size=self.hidden_size,
                                                           bias=False,
                                                           input_is_parallel=True,
                                                           skip_bias_add=False,
                                                           **row_kwargs)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate_up = self.gate_up_proj(x)[0]
        gate, up = gate_up.split(self.gate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)[0]
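

# Illustrative sketch (not part of the original module): for Qwen2,
# config.hidden_act is "silu", so forward() computes the SwiGLU gating
# silu(gate) * up before the row-parallel down projection.
def _demo_swiglu_gating():
    import torch
    import torch.nn.functional as F
    gate, up = torch.randn(4), torch.randn(4)
    return F.silu(gate) * up  # equivalent to ACT2FN["silu"](gate) * up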


================================================
FILE: verl/models/qwen2/megatron/layers/parallel_rmsnorm.py
================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import Qwen2Config

from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils


class ParallelQwen2RMSNorm(nn.Module):

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        if isinstance(config.hidden_size, numbers.Integral):
            normalized_shape = (config.hidden_size,)
        self.normalized_shape = torch.Size(normalized_shape)
        # NOTE: the source preview truncates here; the remainder below is a
        # reconstruction based on the parallel Llama RMSNorm in this repo and
        # may differ from the original file in minor details.
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon)

SYMBOL INDEX (1618 symbols across 165 files)

FILE: verl/models/llama/megatron/checkpoint_utils/llama_loader.py
  function _megatron_calc_layer_map (line 23) | def _megatron_calc_layer_map(config):
  function load_state_dict_to_megatron_llama (line 53) | def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config...

FILE: verl/models/llama/megatron/checkpoint_utils/llama_saver.py
  function _megatron_calc_global_rank (line 32) | def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_ra...
  function _megatron_calc_layer_map (line 49) | def _megatron_calc_layer_map(config):
  function merge_megatron_ckpt_llama (line 80) | def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=Fal...

FILE: verl/models/llama/megatron/layers/parallel_attention.py
  class LlamaRotaryEmbedding (line 35) | class LlamaRotaryEmbedding(nn.Module):
    method __init__ (line 37) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 51) | def _set_cos_sin_cache(self, seq_len, device, dtype):
    method forward (line 61) | def forward(self, x, seq_len=None):
  class LlamaLinearScalingRotaryEmbedding (line 72) | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    method __init__ (line 75) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 79) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  class LlamaDynamicNTKScalingRotaryEmbedding (line 91) | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    method __init__ (line 94) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 98) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  function rotate_half (line 116) | def rotate_half(x):
  function apply_rotary_pos_emb (line 123) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  function repeat_kv (line 131) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  class ParallelLlamaAttention (line 143) | class ParallelLlamaAttention(nn.Module):
    method __init__ (line 146) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method _init_rope (line 204) | def _init_rope(self):
    method _shape (line 231) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 234) | def forward(
  function apply_rotary_pos_emb_rmpad (line 299) | def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, se...
  function apply_rotary_pos_emb_rmpad_flash (line 320) | def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seq...
  class ParallelLlamaAttentionRmPad (line 338) | class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
    method forward (line 340) | def forward(self,

FILE: verl/models/llama/megatron/layers/parallel_decoder.py
  class ParallelLlamaDecoderLayer (line 33) | class ParallelLlamaDecoderLayer(nn.Module):
    method __init__ (line 35) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method forward (line 44) | def forward(
  class ParallelLlamaDecoderLayerRmPad (line 99) | class ParallelLlamaDecoderLayerRmPad(nn.Module):
    method __init__ (line 101) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method forward (line 112) | def forward(

FILE: verl/models/llama/megatron/layers/parallel_linear.py
  class QKVParallelLinear (line 21) | class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    method __init__ (line 23) | def __init__(self,
  class MergedColumnParallelLinear (line 52) | class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    method __init__ (line 54) | def __init__(self,

FILE: verl/models/llama/megatron/layers/parallel_mlp.py
  class ParallelLlamaMLP (line 31) | class ParallelLlamaMLP(nn.Module):
    method __init__ (line 33) | def __init__(self, config, megatron_config: ModelParallelConfig = None...
    method forward (line 71) | def forward(self, x):

FILE: verl/models/llama/megatron/layers/parallel_rmsnorm.py
  class ParallelLlamaRMSNorm (line 25) | class ParallelLlamaRMSNorm(nn.Module):
    method __init__ (line 27) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method forward (line 41) | def forward(self, hidden_states):

FILE: verl/models/llama/megatron/modeling_llama_megatron.py
  function _make_causal_mask (line 45) | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, d...
  function _expand_mask (line 58) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class ParallelLlamaModel (line 72) | class ParallelLlamaModel(nn.Module):
    method __init__ (line 80) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method _prepare_decoder_attention_mask (line 97) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,...
    method forward (line 117) | def forward(
  class ParallelLlamaForCausalLM (line 155) | class ParallelLlamaForCausalLM(nn.Module):
    method __init__ (line 157) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method forward (line 174) | def forward(
  class ParallelLlamaModelRmPad (line 215) | class ParallelLlamaModelRmPad(nn.Module):
    method __init__ (line 223) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method forward (line 240) | def forward(self,
  class ParallelLlamaForCausalLMRmPad (line 279) | class ParallelLlamaForCausalLMRmPad(nn.Module):
    method __init__ (line 281) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method _init_head (line 289) | def _init_head(self):
    method _forward_head (line 301) | def _forward_head(self, hidden_states):
    method forward (line 308) | def forward(
  class ParallelLlamaForValueRmPad (line 366) | class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
    method _init_head (line 368) | def _init_head(self):
    method _forward_head (line 377) | def _forward_head(self, hidden_states):
    method forward (line 384) | def forward(
  class ParallelLlamaModelRmPadPP (line 400) | class ParallelLlamaModelRmPadPP(nn.Module):
    method __init__ (line 410) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method set_input_tensor (line 457) | def set_input_tensor(self, input_tensor):
    method forward (line 467) | def forward(self,
  class ParallelLlamaForCausalLMRmPadPP (line 514) | class ParallelLlamaForCausalLMRmPadPP(nn.Module):
    method __init__ (line 516) | def __init__(self, config: LlamaConfig, megatron_config: ModelParallel...
    method set_input_tensor (line 532) | def set_input_tensor(self, input_tensor):
    method _init_head (line 543) | def _init_head(self):
    method _forward_head (line 555) | def _forward_head(self, hidden_states):
    method forward (line 563) | def forward(
  class ParallelLlamaForValueRmPadPP (line 627) | class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
    method _init_head (line 629) | def _init_head(self):
    method _forward_head (line 638) | def _forward_head(self, hidden_states):
    method forward (line 645) | def forward(

FILE: verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py
  function _megatron_calc_layer_map (line 21) | def _megatron_calc_layer_map(config):
  function load_state_dict_to_megatron_qwen2 (line 51) | def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config...

FILE: verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py
  function _megatron_calc_global_rank (line 28) | def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_ra...
  function _megatron_calc_layer_map (line 45) | def _megatron_calc_layer_map(config):
  function merge_megatron_ckpt_llama (line 76) | def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=Fal...

FILE: verl/models/qwen2/megatron/layers/parallel_attention.py
  class Qwen2RotaryEmbedding (line 35) | class Qwen2RotaryEmbedding(nn.Module):
    method __init__ (line 37) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 51) | def _set_cos_sin_cache(self, seq_len, device, dtype):
    method forward (line 61) | def forward(self, x, seq_len=None):
  class Qwen2LinearScalingRotaryEmbedding (line 72) | class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):
    method __init__ (line 75) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 79) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  class Qwen2DynamicNTKScalingRotaryEmbedding (line 91) | class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):
    method __init__ (line 94) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method _set_cos_sin_cache (line 98) | def _set_cos_sin_cache(self, seq_len, device, dtype):
  function rotate_half (line 116) | def rotate_half(x):
  function apply_rotary_pos_emb (line 123) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  function repeat_kv (line 131) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  class ParallelQwen2Attention (line 143) | class ParallelQwen2Attention(nn.Module):
    method __init__ (line 146) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method _init_rope (line 208) | def _init_rope(self):
    method _shape (line 215) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 218) | def forward(
  function apply_rotary_pos_emb_rmpad (line 283) | def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, se...
  function apply_rotary_pos_emb_rmpad_flash (line 304) | def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seq...
  class ParallelQwen2AttentionRmPad (line 322) | class ParallelQwen2AttentionRmPad(ParallelQwen2Attention):
    method forward (line 324) | def forward(self,

FILE: verl/models/qwen2/megatron/layers/parallel_decoder.py
  class ParallelQwen2DecoderLayer (line 33) | class ParallelQwen2DecoderLayer(nn.Module):
    method __init__ (line 35) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method forward (line 44) | def forward(
  class ParallelQwen2DecoderLayerRmPad (line 99) | class ParallelQwen2DecoderLayerRmPad(nn.Module):
    method __init__ (line 101) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method forward (line 112) | def forward(

FILE: verl/models/qwen2/megatron/layers/parallel_linear.py
  class QKVParallelLinear (line 21) | class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    method __init__ (line 23) | def __init__(self,
  class MergedColumnParallelLinear (line 52) | class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    method __init__ (line 54) | def __init__(self,

FILE: verl/models/qwen2/megatron/layers/parallel_mlp.py
  class ParallelQwen2MLP (line 31) | class ParallelQwen2MLP(nn.Module):
    method __init__ (line 33) | def __init__(self, config, megatron_config: ModelParallelConfig = None...
    method forward (line 71) | def forward(self, x):
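
NOTE: ParallelQwen2MLP shards the gated SwiGLU feed-forward used by Qwen2 across tensor-parallel ranks (typically gate/up column-parallel, down row-parallel, with gate and up fusable via the MergedColumnParallelLinear above). The underlying computation, sketched here without the Megatron parallel linears:

    import torch.nn as nn

    class SwiGLUMLP(nn.Module):
        # Non-parallel reference for the computation ParallelQwen2MLP distributes.
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

        def forward(self, x):
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))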

FILE: verl/models/qwen2/megatron/layers/parallel_rmsnorm.py
  class ParallelQwen2RMSNorm (line 25) | class ParallelQwen2RMSNorm(nn.Module):
    method __init__ (line 27) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method forward (line 41) | def forward(self, hidden_states):
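
NOTE: RMSNorm normalizes by the root-mean-square of the hidden vector rather than subtracting a mean. A reference version of the standard computation (the Parallel variant adds the Megatron config plumbing):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            # Accumulate in fp32 for stability, then cast back to the input dtype.
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states.to(input_dtype)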

FILE: verl/models/qwen2/megatron/modeling_qwen2_megatron.py
  function _make_causal_mask (line 45) | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, d...
  function _expand_mask (line 58) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class ParallelQwen2Model (line 72) | class ParallelQwen2Model(nn.Module):
    method __init__ (line 80) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method _prepare_decoder_attention_mask (line 97) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,...
    method forward (line 117) | def forward(
  class ParallelQwen2ForCausalLM (line 155) | class ParallelQwen2ForCausalLM(nn.Module):
    method __init__ (line 157) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method forward (line 174) | def forward(
  class ParallelQwen2ModelRmPad (line 215) | class ParallelQwen2ModelRmPad(nn.Module):
    method __init__ (line 223) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method forward (line 240) | def forward(self,
  class ParallelQwen2ForCausalLMRmPad (line 279) | class ParallelQwen2ForCausalLMRmPad(nn.Module):
    method __init__ (line 281) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method _init_head (line 289) | def _init_head(self):
    method _forward_head (line 301) | def _forward_head(self, hidden_states):
    method forward (line 308) | def forward(
  class ParallelQwen2ForValueRmPad (line 366) | class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):
    method _init_head (line 368) | def _init_head(self):
    method _forward_head (line 377) | def _forward_head(self, hidden_states):
    method forward (line 384) | def forward(
  class ParallelQwen2ModelRmPadPP (line 400) | class ParallelQwen2ModelRmPadPP(nn.Module):
    method __init__ (line 410) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method set_input_tensor (line 457) | def set_input_tensor(self, input_tensor):
    method forward (line 467) | def forward(self,
  class ParallelQwen2ForCausalLMRmPadPP (line 514) | class ParallelQwen2ForCausalLMRmPadPP(nn.Module):
    method __init__ (line 516) | def __init__(self, config: Qwen2Config, megatron_config: ModelParallel...
    method set_input_tensor (line 534) | def set_input_tensor(self, input_tensor):
    method _init_head (line 545) | def _init_head(self):
    method setup_embeddings_and_output_layer (line 559) | def setup_embeddings_and_output_layer(self) -> None:
    method shared_embedding_or_output_weight (line 599) | def shared_embedding_or_output_weight(self) -> torch.Tensor:
    method _forward_head (line 606) | def _forward_head(self, hidden_states):
    method forward (line 617) | def forward(
  class ParallelQwen2ForValueRmPadPP (line 680) | class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):
    method _init_head (line 682) | def _init_head(self):
    method _forward_head (line 691) | def _forward_head(self, hidden_states):
    method forward (line 698) | def forward(
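
NOTE: _make_causal_mask and _expand_mask mirror the mask utilities from older transformers releases: the upper triangle is filled with the dtype minimum so softmax assigns near-zero weight to future positions. A sketch of the conventional construction:

    import torch

    def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        cond = torch.arange(tgt_len, device=device)
        mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)  # zero at/below the diagonal
        return mask.to(dtype)[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)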

FILE: verl/models/registry.py
  function check_model_support_rmpad (line 25) | def check_model_support_rmpad(model_type: str):
  class ModelRegistry (line 55) | class ModelRegistry:
    method load_model_cls (line 58) | def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.M...
    method get_supported_archs (line 74) | def get_supported_archs() -> List[str]:
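
NOTE: ModelRegistry.load_model_cls maps a Hugging Face architecture string (with value=True selecting the value-head variant) to one of the Megatron implementations above. A hypothetical sketch of the lookup pattern; the table entries here are illustrative, not the repository's actual mapping:

    import importlib
    from typing import List, Optional, Type
    import torch.nn as nn

    # Illustrative entries only; the real table lives in verl/models/registry.py.
    _MODELS = {
        "LlamaForCausalLM": ("llama", "ParallelLlamaForCausalLMRmPadPP"),
        "Qwen2ForCausalLM": ("qwen2", "ParallelQwen2ForCausalLMRmPadPP"),
    }

    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in _MODELS:
            return None
        module_name, cls_name = _MODELS[model_arch]
        module = importlib.import_module(
            f"verl.models.{module_name}.megatron.modeling_{module_name}_megatron")
        return getattr(module, cls_name)

    def get_supported_archs() -> List[str]:
        return list(_MODELS.keys())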

FILE: verl/models/transformers/llama.py
  function llama_flash_attn_forward (line 31) | def llama_flash_attn_forward(
  function llama_attn_forward (line 154) | def llama_attn_forward(

FILE: verl/models/transformers/monkey_patch.py
  function apply_monkey_patch_to_llama (line 19) | def apply_monkey_patch_to_llama():
  function apply_monkey_patch_to_qwen2 (line 30) | def apply_monkey_patch_to_qwen2():
  function apply_monkey_patch (line 49) | def apply_monkey_patch(config: PretrainedConfig, verbose=True):
  function is_transformers_version_in_range (line 73) | def is_transformers_version_in_range(min_version: str, max_version: str)...
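
NOTE: apply_monkey_patch swaps the stock transformers attention forward for the rmpad/sequence-parallel-aware versions defined in llama.py and qwen2.py, guarded by a transformers version check. An illustrative sketch of the patching pattern; the exact class path is an assumption and depends on the installed transformers version:

    # Hypothetical patch target; transformers reorganized attention classes across versions.
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
    from verl.models.transformers.llama import llama_flash_attn_forward

    def apply_monkey_patch_to_llama():
        # Rebinding the class attribute affects all instances, including
        # models constructed before the patch ran.
        LlamaFlashAttention2.forward = llama_flash_attn_forward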

FILE: verl/models/transformers/qwen2.py
  function qwen2_flash_attn_forward (line 28) | def qwen2_flash_attn_forward(
  function qwen2_attn_forward (line 145) | def qwen2_attn_forward(

FILE: verl/models/transformers/qwen2_vl.py
  function get_rope_index (line 31) | def get_rope_index(
  function prepare_fa2_from_position_ids (line 134) | def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor...
  function flash_attention_forward (line 149) | def flash_attention_forward(
  function ulysses_flash_attn_forward (line 217) | def ulysses_flash_attn_forward(

FILE: verl/models/weight_loader_registry.py
  function get_weight_loader (line 16) | def get_weight_loader(arch: str):

FILE: verl/protocol.py
  function pad_dataproto_to_divisor (line 41) | def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
  function unpad_dataproto (line 67) | def unpad_dataproto(data: 'DataProto', pad_size):
  function union_tensor_dict (line 73) | def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict...
  function union_numpy_dict (line 87) | def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: ...
  function list_of_dict_to_dict_of_list (line 100) | def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
  function fold_batch_dim (line 112) | def fold_batch_dim(data: 'DataProto', new_batch_size):
  function unfold_batch_dim (line 132) | def unfold_batch_dim(data: 'DataProto', batch_dims=2):
  function collate_fn (line 151) | def collate_fn(x: list['DataProtoItem']):
  class DataProtoItem (line 165) | class DataProtoItem:
  class DataProto (line 173) | class DataProto:
    method __post_init__ (line 184) | def __post_init__(self):
    method __len__ (line 188) | def __len__(self):
    method __getitem__ (line 197) | def __getitem__(self, item):
    method __getstate__ (line 202) | def __getstate__(self):
    method __setstate__ (line 212) | def __setstate__(self, data):
    method save_to_disk (line 223) | def save_to_disk(self, filepath):
    method load_from_disk (line 228) | def load_from_disk(filepath) -> 'DataProto':
    method print_size (line 233) | def print_size(self, prefix=""):
    method check_consistency (line 250) | def check_consistency(self):
    method from_single_dict (line 274) | def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarr...
    method from_dict (line 289) | def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None,...
    method to (line 324) | def to(self, device) -> 'DataProto':
    method select (line 338) | def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_inf...
    method pop (line 373) | def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_k...
    method rename (line 405) | def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
    method union (line 431) | def union(self, other: 'DataProto') -> 'DataProto':
    method make_iterator (line 450) | def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader...
    method chunk (line 489) | def chunk(self, chunks: int) -> List['DataProto']:
    method concat (line 522) | def concat(data: List['DataProto']) -> 'DataProto':
    method reorder (line 546) | def reorder(self, indices):
    method repeat (line 554) | def repeat(self, repeat_times=2, interleave=True):
  class DataProtoFuture (line 603) | class DataProtoFuture:
    method concat (line 620) | def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
    method chunk (line 624) | def chunk(self, chunks: int) -> List['DataProtoFuture']:
    method get (line 639) | def get(self):
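
NOTE: DataProto is the batch container exchanged between the driver and workers; chunk and concat are inverses that the dispatch functions below use to shard a batch across data-parallel ranks and merge the results. A usage sketch based on the signatures listed above:

    import torch
    from verl.protocol import DataProto

    batch = DataProto.from_dict(
        tensors={"input_ids": torch.randint(0, 32000, (8, 16))},
        meta_info={"temperature": 1.0},
    )
    shards = batch.chunk(chunks=4)        # four shards of batch size 2
    restored = DataProto.concat(shards)   # concat undoes chunk
    assert len(restored) == len(batch)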

FILE: verl/single_controller/base/decorator.py
  class Dispatch (line 25) | class Dispatch(Enum):
  class Execute (line 40) | class Execute(Enum):
  function _split_args_kwargs_data_proto (line 45) | def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
  function dispatch_one_to_all (line 60) | def dispatch_one_to_all(worker_group, *args, **kwargs):
  function dispatch_all_to_all (line 66) | def dispatch_all_to_all(worker_group, *args, **kwargs):
  function collect_all_to_all (line 70) | def collect_all_to_all(worker_group, output):
  function dispatch_megatron_compute (line 74) | def dispatch_megatron_compute(worker_group, *args, **kwargs):
  function collect_megatron_compute (line 103) | def collect_megatron_compute(worker_group, output):
  function dispatch_megatron_compute_data_proto (line 118) | def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
  function _concat_data_proto_or_future (line 129) | def _concat_data_proto_or_future(output: List):
  function collect_megatron_compute_data_proto (line 147) | def collect_megatron_compute_data_proto(worker_group, output):
  function dispatch_megatron_pp_as_dp (line 161) | def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
  function collect_megatron_pp_as_dp (line 209) | def collect_megatron_pp_as_dp(worker_group, output):
  function collect_megatron_pp_only (line 223) | def collect_megatron_pp_only(worker_group, output):
  function dispatch_megatron_pp_as_dp_data_proto (line 237) | def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
  function collect_megatron_pp_as_dp_data_proto (line 246) | def collect_megatron_pp_as_dp_data_proto(worker_group, output):
  function dispatch_dp_compute (line 255) | def dispatch_dp_compute(worker_group, *args, **kwargs):
  function collect_dp_compute (line 265) | def collect_dp_compute(worker_group, output):
  function dispatch_dp_compute_data_proto (line 272) | def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
  function dispatch_dp_compute_data_proto_with_func (line 279) | def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwar...
  function collect_dp_compute_data_proto (line 289) | def collect_dp_compute_data_proto(worker_group, output):
  function get_predefined_dispatch_fn (line 300) | def get_predefined_dispatch_fn(dispatch_mode):
  function get_predefined_execute_fn (line 350) | def get_predefined_execute_fn(execute_mode):
  function _check_dispatch_mode (line 366) | def _check_dispatch_mode(dispatch_mode):
  function _check_execute_mode (line 375) | def _check_execute_mode(execute_mode):
  function _materialize_futures (line 379) | def _materialize_futures(*args, **kwargs):
  function register (line 394) | def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL...
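
NOTE: register attaches a dispatch mode and execute mode to a Worker method; the worker group then splits the inputs with the matching dispatch_* function and merges outputs with the collect_* function. An illustrative sketch (the Dispatch member name is inferred from dispatch_dp_compute_data_proto above):

    from verl.protocol import DataProto
    from verl.single_controller.base.decorator import Dispatch, register
    from verl.single_controller.base.worker import Worker

    class MyWorker(Worker):
        @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
        def compute(self, data: DataProto) -> DataProto:
            # Each rank receives one chunk of the caller's batch; the
            # registered collect_fn concatenates the per-rank outputs.
            return data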

FILE: verl/single_controller/base/megatron/worker.py
  class MegatronWorker (line 18) | class MegatronWorker(Worker):
    method __init__ (line 20) | def __init__(self, cuda_visible_devices=None) -> None:
    method get_megatron_global_info (line 23) | def get_megatron_global_info(self):
    method get_megatron_rank_info (line 31) | def get_megatron_rank_info(self):

FILE: verl/single_controller/base/megatron/worker_group.py
  class MegatronWorkerGroup (line 21) | class MegatronWorkerGroup(WorkerGroup):
    method __init__ (line 23) | def __init__(self, resource_pool: ResourcePool, **kwargs):
    method init_megatron (line 28) | def init_megatron(self, default_megatron_kwargs: Dict = None):
    method get_megatron_rank_info (line 31) | def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
    method tp_size (line 36) | def tp_size(self):
    method dp_size (line 41) | def dp_size(self):
    method pp_size (line 46) | def pp_size(self):
    method get_megatron_global_info (line 50) | def get_megatron_global_info(self):

FILE: verl/single_controller/base/register_center/ray.py
  class WorkerGroupRegisterCenter (line 19) | class WorkerGroupRegisterCenter:
    method __init__ (line 21) | def __init__(self, rank_zero_info):
    method get_rank_zero_info (line 24) | def get_rank_zero_info(self):
  function create_worker_group_register_center (line 28) | def create_worker_group_register_center(name, info):

FILE: verl/single_controller/base/worker.py
  class DistRankInfo (line 24) | class DistRankInfo:
  class DistGlobalInfo (line 31) | class DistGlobalInfo:
  class WorkerHelper (line 37) | class WorkerHelper:
    method _get_node_ip (line 39) | def _get_node_ip(self):
    method _get_free_port (line 56) | def _get_free_port(self):
    method get_availale_master_addr_port (line 61) | def get_availale_master_addr_port(self):
    method _get_pid (line 64) | def _get_pid(self):
  class WorkerMeta (line 68) | class WorkerMeta:
    method __init__ (line 73) | def __init__(self, store) -> None:
    method to_dict (line 76) | def to_dict(self):
  class Worker (line 81) | class Worker(WorkerHelper):
    method __new__ (line 84) | def __new__(cls, *args, **kwargs):
    method _configure_before_init (line 101) | def _configure_before_init(self, register_center_name: str, rank: int):
    method __init__ (line 118) | def __init__(self, cuda_visible_devices=None) -> None:
    method _configure_with_meta (line 146) | def _configure_with_meta(self, meta: WorkerMeta):
    method get_master_addr_port (line 161) | def get_master_addr_port(self):
    method get_cuda_visible_devices (line 164) | def get_cuda_visible_devices(self):
    method world_size (line 170) | def world_size(self):
    method rank (line 174) | def rank(self):
    method execute_with_func_generator (line 178) | def execute_with_func_generator(self, func, *args, **kwargs):
    method execute_func_rank_zero (line 183) | def execute_func_rank_zero(self, func, *args, **kwargs):

FILE: verl/single_controller/base/worker_group.py
  class ResourcePool (line 26) | class ResourcePool:
    method __init__ (line 29) | def __init__(self, process_on_nodes=None, max_collocate_count: int = 1...
    method add_node (line 36) | def add_node(self, process_count):
    method world_size (line 40) | def world_size(self):
    method __call__ (line 43) | def __call__(self) -> Any:
    method store (line 47) | def store(self):
    method local_world_size_list (line 50) | def local_world_size_list(self) -> List[int]:
    method local_rank_list (line 56) | def local_rank_list(self) -> List[int]:
  class ClassWithInitArgs (line 61) | class ClassWithInitArgs:
    method __init__ (line 67) | def __init__(self, cls, *args, **kwargs) -> None:
    method __call__ (line 78) | def __call__(self) -> Any:
  function check_workers_alive (line 82) | def check_workers_alive(workers: List, is_alive: Callable, gap_time: flo...
  class WorkerGroup (line 92) | class WorkerGroup:
    method __init__ (line 95) | def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
    method _is_worker_alive (line 112) | def _is_worker_alive(self, worker):
    method _block_until_all_workers_alive (line 115) | def _block_until_all_workers_alive(self) -> None:
    method start_worker_aliveness_check (line 123) | def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
    method world_size (line 132) | def world_size(self):
    method _bind_worker_method (line 138) | def _bind_worker_method(self, user_defined_cls, func_generator):

FILE: verl/single_controller/ray/base.py
  function get_random_string (line 29) | def get_random_string(length: int) -> str:
  function func_generator (line 36) | def func_generator(self, method_name, dispatch_fn, collect_fn, execute_f...
  class RayResourcePool (line 49) | class RayResourcePool(ResourcePool):
    method __init__ (line 51) | def __init__(self,
    method get_placement_groups (line 64) | def get_placement_groups(self, strategy="STRICT_PACK", name=None):
  function extract_pg_from_exist (line 91) | def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], sr...
  function merge_resource_pool (line 114) | def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> R...
  class RayClassWithInitArgs (line 128) | class RayClassWithInitArgs(ClassWithInitArgs):
    method __init__ (line 130) | def __init__(self, cls, *args, **kwargs) -> None:
    method set_additional_resource (line 136) | def set_additional_resource(self, additional_resource):
    method update_options (line 139) | def update_options(self, options: Dict):
    method __call__ (line 142) | def __call__(self,
  class RayWorkerGroup (line 176) | class RayWorkerGroup(WorkerGroup):
    method __init__ (line 178) | def __init__(self,
    method _is_worker_alive (line 205) | def _is_worker_alive(self, worker: ray.actor.ActorHandle):
    method _init_with_detached_workers (line 209) | def _init_with_detached_workers(self, worker_names):
    method _init_with_resource_pool (line 214) | def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, b...
    method worker_names (line 281) | def worker_names(self):
    method from_detached (line 285) | def from_detached(cls, worker_names=None, ray_cls_with_init=None):
    method spawn (line 292) | def spawn(self, prefix_set):
    method execute_rank_zero_sync (line 319) | def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
    method execute_rank_zero_async (line 322) | def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
    method execute_rank_zero (line 326) | def execute_rank_zero(self, method_name: str, *args, **kwargs):
    method execute_all (line 329) | def execute_all(self, method_name: str, *args, **kwargs):
    method execute_all_sync (line 332) | def execute_all_sync(self, method_name: str, *args, **kwargs):
    method execute_all_async (line 335) | def execute_all_async(self, method_name: str, *args, **kwargs):
    method master_address (line 354) | def master_address(self):
    method master_port (line 358) | def master_port(self):
    method workers (line 362) | def workers(self):
    method world_size (line 366) | def world_size(self):
  function _bind_workers_method_to_parent (line 380) | def _bind_workers_method_to_parent(cls, key, user_defined_cls):
  function _unwrap_ray_remote (line 414) | def _unwrap_ray_remote(cls):
  function create_colocated_worker_cls (line 420) | def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitAr...
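
NOTE: Putting the Ray pieces together: RayResourcePool reserves placement groups, RayClassWithInitArgs defers actor construction, and RayWorkerGroup spawns the actors and binds their registered methods onto itself. A minimal usage sketch under those assumptions, reusing the MyWorker class from the decorator example above:

    import ray
    from verl.single_controller.ray.base import (
        RayClassWithInitArgs, RayResourcePool, RayWorkerGroup)

    ray.init()
    pool = RayResourcePool([4])  # one node with four worker processes
    cls_with_args = RayClassWithInitArgs(cls=ray.remote(MyWorker))
    wg = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=cls_with_args)
    # Blocking call on every worker; returns one result per rank.
    results = wg.execute_all_sync("get_cuda_visible_devices")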

FILE: verl/single_controller/ray/megatron.py
  class NVMegatronRayWorkerGroup (line 25) | class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    method __init__ (line 31) | def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: ...
  class MegatronRayWorkerGroup (line 38) | class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    method __init__ (line 44) | def __init__(self,
    method init_megatron (line 58) | def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):

FILE: verl/third_party/vllm/__init__.py
  function get_version (line 19) | def get_version(pkg):

FILE: verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py
  function gemma_dtensor_weight_loader (line 24) | def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function gptbigcode_dtensor_load_weights (line 61) | def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function starcoder2_dtensor_load_weights (line 76) | def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function llama_dtensor_weight_loader (line 107) | def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2_dtensor_weight_loader (line 151) | def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2vl_dtensor_weight_loader (line 188) | def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Mo...
  function deepseekv2_dtensor_weight_loader (line 237) | def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn...
  function gpt2_dtensor_weight_loader (line 317) | def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modul...
  function redistribute_dtensor (line 321) | def redistribute_dtensor(param_name: str, loaded_weights: DTensor, paral...
  function _process_parameter_names (line 335) | def _process_parameter_names(name):
  function load_dtensor_weights (line 373) | def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 381) | def _get_model_weight_loader(arch: str):
  function update_dtensor_weight_loader (line 389) | def update_dtensor_weight_loader():
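
NOTE: The per-architecture loaders above exist because vLLM fuses projections that HF-style checkpoints keep separate (q/k/v into qkv_proj, gate/up into gate_up_proj); each loader remaps parameter names and copies shards after redistribute_dtensor gathers the DTensor. A sketch of the conventional fusion loop (mirroring vLLM's llama loader pattern, not this file verbatim):

    from typing import Dict
    import torch.nn as nn

    def llama_style_weight_loader(actor_weights: Dict, vllm_model: nn.Module):
        stacked_params_mapping = [
            # (fused vLLM param, HF checkpoint weight, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(vllm_model.named_parameters())
        for name, loaded_weight in actor_weights.items():
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                # Fused params carry a weight_loader that places each shard.
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", lambda p, w: p.data.copy_(w))
                loader(param, loaded_weight)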

FILE: verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py
  class EngineArgs (line 27) | class EngineArgs:
    method add_cli_args (line 61) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...
    method from_cli_args (line 178) | def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
    method create_engine_configs (line 185) | def create_engine_configs(
  class AsyncEngineArgs (line 208) | class AsyncEngineArgs(EngineArgs):
    method add_cli_args (line 215) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...

FILE: verl/third_party/vllm/vllm_v_0_3_1/config.py
  class ModelConfig (line 31) | class ModelConfig:
    method __init__ (line 75) | def __init__(
    method _verify_load_format (line 109) | def _verify_load_format(self) -> None:
    method _verify_quantization (line 124) | def _verify_quantization(self) -> None:
    method _verify_cuda_graph (line 153) | def _verify_cuda_graph(self) -> None:
    method verify_with_parallel_config (line 163) | def verify_with_parallel_config(
    method get_sliding_window (line 181) | def get_sliding_window(self) -> Optional[int]:
    method get_vocab_size (line 184) | def get_vocab_size(self) -> int:
    method get_hidden_size (line 187) | def get_hidden_size(self) -> int:
    method get_head_size (line 190) | def get_head_size(self) -> int:
    method get_total_num_kv_heads (line 194) | def get_total_num_kv_heads(self) -> int:
    method get_num_kv_heads (line 226) | def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
    method get_num_layers (line 235) | def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  class CacheConfig (line 240) | class CacheConfig:
    method __init__ (line 251) | def __init__(
    method _verify_args (line 271) | def _verify_args(self) -> None:
    method _verify_cache_dtype (line 276) | def _verify_cache_dtype(self) -> None:
    method verify_with_parallel_config (line 294) | def verify_with_parallel_config(
  class ParallelConfig (line 313) | class ParallelConfig:
    method __init__ (line 329) | def __init__(
    method _verify_args (line 348) | def _verify_args(self) -> None:
  class SchedulerConfig (line 370) | class SchedulerConfig:
    method __init__ (line 383) | def __init__(
    method _verify_args (line 401) | def _verify_args(self) -> None:
  class DeviceConfig (line 415) | class DeviceConfig:
    method __init__ (line 417) | def __init__(self, device: str = "cuda") -> None:
  class LoRAConfig (line 422) | class LoRAConfig:
    method __post_init__ (line 431) | def __post_init__(self):
    method verify_with_model_config (line 449) | def verify_with_model_config(self, model_config: ModelConfig):
    method verify_with_scheduler_config (line 457) | def verify_with_scheduler_config(self, scheduler_config: SchedulerConf...
  function _get_and_verify_dtype (line 475) | def _get_and_verify_dtype(
  function _get_and_verify_max_len (line 525) | def _get_and_verify_max_len(

FILE: verl/third_party/vllm/vllm_v_0_3_1/llm.py
  class LLM (line 33) | class LLM:
    method __init__ (line 87) | def __init__(
    method init_cache_engine (line 133) | def init_cache_engine(self):
    method free_cache_engine (line 136) | def free_cache_engine(self):
    method get_tokenizer (line 139) | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokeni...
    method set_tokenizer (line 142) | def set_tokenizer(
    method generate (line 148) | def generate(
    method _add_request (line 201) | def _add_request(
    method _run_engine (line 217) | def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
    method _pre_process_inputs (line 242) | def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[...
    method _post_process_outputs (line 250) | def _post_process_outputs(self, outputs: List[RequestOutput]) -> Tuple...
    method sync_model_weights (line 271) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -...
    method offload_model_weights (line 274) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py
  class LLMEngine (line 41) | class LLMEngine:
    method __init__ (line 70) | def __init__(
    method _init_tokenizer (line 138) | def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
    method get_tokenizer_for_seq (line 146) | def get_tokenizer_for_seq(self, sequence: Sequence):
    method _init_workers_sp (line 149) | def _init_workers_sp(self, model, distributed_init_method: str):
    method _verify_args (line 172) | def _verify_args(self) -> None:
    method _init_cache_sp (line 176) | def _init_cache_sp(self) -> None:
    method init_cache_engine (line 215) | def init_cache_engine(self):
    method free_cache_engine (line 218) | def free_cache_engine(self):
    method from_engine_args (line 222) | def from_engine_args(cls, model, tokenizer, engine_args: EngineArgs) -...
    method add_request (line 238) | def add_request(
    method abort_request (line 317) | def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
    method get_model_config (line 336) | def get_model_config(self) -> ModelConfig:
    method get_num_unfinished_requests (line 340) | def get_num_unfinished_requests(self) -> int:
    method has_unfinished_requests (line 344) | def has_unfinished_requests(self) -> bool:
    method _check_beam_search_early_stopping (line 348) | def _check_beam_search_early_stopping(
    method _process_sequence_group_outputs (line 385) | def _process_sequence_group_outputs(self, seq_group: SequenceGroup, ou...
    method _process_model_outputs (line 545) | def _process_model_outputs(self, output: SamplerOutput, scheduler_outp...
    method step (line 574) | def step(self) -> List[RequestOutput]:
    method do_log_stats (line 595) | def do_log_stats(self) -> None:
    method _get_stats (line 600) | def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs]) ->...
    method _decode_sequence (line 662) | def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
    method _check_stop (line 681) | def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) ...
    method _finalize_sequence (line 710) | def _finalize_sequence(self, seq: Sequence, sampling_params: SamplingP...
    method add_lora (line 716) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 720) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 724) | def list_loras(self) -> List[int]:
    method sync_model_weights (line 727) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor]) -...
    method offload_model_weights (line 730) | def offload_model_weights(self) -> None:
  function initialize_cluster (line 734) | def initialize_cluster(
  function get_open_port (line 762) | def get_open_port():
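
NOTE: get_open_port conventionally asks the kernel for an ephemeral port by binding to port 0. A sketch of the usual implementation:

    import socket

    def get_open_port() -> int:
        # Binding to port 0 makes the OS pick a currently free port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]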

FILE: verl/third_party/vllm/vllm_v_0_3_1/model_loader.py
  function _set_default_torch_dtype (line 38) | def _set_default_torch_dtype(dtype: torch.dtype):
  function _get_model_architecture (line 46) | def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
  function vocab_init (line 87) | def vocab_init(self,
  function _get_model_weight_loader (line 123) | def _get_model_weight_loader(arch: str):
  function get_model (line 130) | def get_model(actor_model: Union[PreTrainedModel, Dict],
  function load_weights (line 181) | def load_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_logits (line 193) | def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
  function forward (line 206) | def forward(

FILE: verl/third_party/vllm/vllm_v_0_3_1/model_runner.py
  class ModelRunner (line 46) | class ModelRunner(ModelRunner):
    method __init__ (line 48) | def __init__(
    method load_model (line 90) | def load_model(self) -> None:
    method _prepare_sample (line 109) | def _prepare_sample(
    method prepare_input_tensors (line 174) | def prepare_input_tensors(
    method execute_model (line 203) | def execute_model(
    method profile_run (line 235) | def profile_run(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py
  function initialize_model_parallel_from_megatron (line 26) | def initialize_model_parallel_from_megatron(
  function get_tensor_model_parallel_group (line 108) | def get_tensor_model_parallel_group():
  function get_tensor_model_parallel_world_size (line 114) | def get_tensor_model_parallel_world_size():
  function get_tensor_model_parallel_rank (line 119) | def get_tensor_model_parallel_rank():
  function get_tensor_model_parallel_src_rank (line 124) | def get_tensor_model_parallel_src_rank():
  function get_micro_data_parallel_group (line 137) | def get_micro_data_parallel_group():
  function get_micro_data_parallel_world_size (line 142) | def get_micro_data_parallel_world_size():
  function get_micro_data_parallel_rank (line 146) | def get_micro_data_parallel_rank():

FILE: verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py
  class TokenizerGroup (line 25) | class TokenizerGroup:
    method __init__ (line 28) | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, ...
    method encode (line 38) | def encode(self,
    method encode_async (line 45) | async def encode_async(self,
    method get_lora_tokenizer (line 52) | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "...
    method pad_token_id (line 67) | def pad_token_id(self):
    method eos_token_id (line 71) | def eos_token_id(self):

FILE: verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py
  function parallel_weight_loader (line 22) | def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: tor...
  function default_weight_loader (line 32) | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tens...
  function gpt2_weight_loader (line 40) | def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn...
  function llama_weight_loader (line 68) | def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> n...
  function mistral_weight_loader (line 83) | def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) ->...

FILE: verl/third_party/vllm/vllm_v_0_3_1/worker.py
  class Worker (line 39) | class Worker:
    method __init__ (line 47) | def __init__(
    method init_model (line 89) | def init_model(self, cupy_port: Optional[int] = None):
    method load_model (line 117) | def load_model(self):
    method profile_num_available_blocks (line 121) | def profile_num_available_blocks(
    method init_cache_engine (line 168) | def init_cache_engine(self, cache_config: CacheConfig) -> None:
    method free_cache_engine (line 176) | def free_cache_engine(self):
    method warm_up_model (line 181) | def warm_up_model(self) -> None:
    method cache_swap (line 188) | def cache_swap(
    method execute_model (line 215) | def execute_model(
    method sync_model_weights (line 247) | def sync_model_weights(self, actor_weights: Dict):
    method offload_model_weights (line 250) | def offload_model_weights(self) -> None:
    method add_lora (line 260) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 263) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 266) | def list_loras(self) -> Set[int]:
  function _init_distributed_environment (line 270) | def _init_distributed_environment(
  function _pad_to_alignment (line 298) | def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[...
  function _pad_to_max (line 302) | def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
  function _check_if_gpu_supports_dtype (line 306) | def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
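
NOTE: _pad_to_alignment and _pad_to_max are small list utilities used when batching token ids; their conventional one-line bodies:

    from typing import List

    def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
        # Pad so len(x) becomes a multiple of multiple_of.
        return x + [pad] * ((-len(x)) % multiple_of)

    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        # Right-pad to exactly max_len (assumes len(x) <= max_len).
        return x + [pad] * (max_len - len(x))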

FILE: verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py
  function nullable_str (line 33) | def nullable_str(val: str):
  class EngineArgs (line 40) | class EngineArgs:
    method add_cli_args (line 106) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...
    method from_cli_args (line 223) | def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
    method create_engine_config (line 230) | def create_engine_config(

FILE: verl/third_party/vllm/vllm_v_0_4_2/config.py
  class ModelConfig (line 37) | class ModelConfig(ModelConfig):
    method __init__ (line 98) | def __init__(
  class LoadFormat (line 147) | class LoadFormat(str, enum.Enum):
  class LoadConfig (line 158) | class LoadConfig:
    method __post_init__ (line 180) | def __post_init__(self):
    method _verify_load_format (line 186) | def _verify_load_format(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py
  function gemma_dtensor_weight_loader (line 26) | def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function gptbigcode_dtensor_load_weights (line 74) | def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function starcoder2_dtensor_load_weights (line 89) | def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function llama_dtensor_weight_loader (line 120) | def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2_dtensor_weight_loader (line 164) | def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function gpt2_dtensor_weight_loader (line 201) | def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modul...
  function redistribute_dtensor (line 205) | def redistribute_dtensor(param_name: str, loaded_weights: DTensor, paral...
  function _process_parameter_names (line 218) | def _process_parameter_names(name):
  function load_dtensor_weights (line 252) | def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 260) | def _get_model_weight_loader(arch: str):
  function update_dtensor_weight_loader (line 268) | def update_dtensor_weight_loader():

FILE: verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py
  function update_hf_weight_loader (line 25) | def update_hf_weight_loader():
  function gemma_load_weights (line 30) | def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  function load_hf_weights (line 79) | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):

FILE: verl/third_party/vllm/vllm_v_0_4_2/llm.py
  class LLM (line 35) | class LLM:
    method __init__ (line 89) | def __init__(
    method init_cache_engine (line 137) | def init_cache_engine(self):
    method free_cache_engine (line 140) | def free_cache_engine(self):
    method get_tokenizer (line 143) | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokeni...
    method set_tokenizer (line 146) | def set_tokenizer(
    method generate (line 152) | def generate(
    method _add_request (line 232) | def _add_request(
    method _run_engine (line 248) | def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
    method _pre_process_inputs (line 273) | def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[...
    method _post_process_outputs (line 281) | def _post_process_outputs(self, request_outputs: List[RequestOutput]) ...
    method sync_model_weights (line 302) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 305) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py
  class LLMEngine (line 43) | class LLMEngine(LLMEngine):
    method __init__ (line 74) | def __init__(
    method _init_tokenizer (line 229) | def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
    method init_cache_engine (line 236) | def init_cache_engine(self):
    method free_cache_engine (line 241) | def free_cache_engine(self):
    method from_engine_args (line 247) | def from_engine_args(
    method sync_model_weights (line 279) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 282) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py
  function parallel_weight_loader (line 27) | def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: tor...
  function default_weight_loader (line 37) | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tens...
  function gpt2_weight_loader (line 45) | def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn...
  function llama_megatron_weight_loader (line 73) | def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Mod...
  function llama_megatron_core_te_weight_loader (line 85) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 116) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 146) | def _replace_name(megatron_name, name_mapping):
  function llama_megatron_core_te_weight_loader (line 169) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 200) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 230) | def _replace_name(megatron_name, name_mapping):
  function mistral_megatron_weight_loader (line 253) | def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.M...
  function load_megatron_weights (line 290) | def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 298) | def _get_model_weight_loader(arch: str):
  function update_megatron_weight_loader (line 305) | def update_megatron_weight_loader():
  function vocab_init (line 316) | def vocab_init(self,

FILE: verl/third_party/vllm/vllm_v_0_4_2/model_loader.py
  function get_model (line 34) | def get_model(actor_model: Union[PreTrainedModel, Dict], model_config: M...
  function get_model_loader (line 55) | def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  class DummyModelLoader (line 94) | class DummyModelLoader(BaseModelLoader):
    method __init__ (line 97) | def __init__(self, load_config: LoadConfig):
    method load_model (line 103) | def load_model(self, *, model_config: ModelConfig, device_config: Devi...
  class MegatronLoader (line 115) | class MegatronLoader(BaseModelLoader):
    method __init__ (line 118) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 124) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 133) | def load_model(self, actor_model: Union[PreTrainedModel,
  class HFLoader (line 161) | class HFLoader(BaseModelLoader):
    method __init__ (line 164) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 170) | def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Di...
    method load_model (line 178) | def load_model(self, actor_model: Union[PreTrainedModel,
  class DTensorLoader (line 200) | class DTensorLoader(BaseModelLoader):
    method __init__ (line 203) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 209) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 218) | def load_model(self, actor_model: Union[PreTrainedModel,
  function _get_logits (line 250) | def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,

FILE: verl/third_party/vllm/vllm_v_0_4_2/model_runner.py
  class BatchType (line 39) | class BatchType(IntEnum):
  class ModelRunner (line 48) | class ModelRunner(ModelRunner):
    method __init__ (line 50) | def __init__(
    method load_model (line 105) | def load_model(self) -> None:
    method prepare_input_tensors (line 147) | def prepare_input_tensors(
    method execute_model (line 238) | def execute_model(

FILE: verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py
  function initialize_parallel_state (line 35) | def initialize_parallel_state(
  function ensure_model_parallel_initialized (line 66) | def ensure_model_parallel_initialized(
  function model_parallel_is_initialized (line 92) | def model_parallel_is_initialized():
  function initialize_model_parallel_for_vllm (line 98) | def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
  function initialize_model_parallel (line 172) | def initialize_model_parallel(
  function get_device_mesh (line 263) | def get_device_mesh():
  function get_tensor_model_parallel_group (line 273) | def get_tensor_model_parallel_group():
  function get_tensor_model_parallel_world_size (line 279) | def get_tensor_model_parallel_world_size():
  function get_tensor_model_parallel_rank (line 284) | def get_tensor_model_parallel_rank():
  function get_tensor_model_parallel_src_rank (line 289) | def get_tensor_model_parallel_src_rank():

FILE: verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py
  class SPMDGPUExecutor (line 33) | class SPMDGPUExecutor(ExecutorBase):
    method __init__ (line 36) | def __init__(
    method _init_executor (line 63) | def _init_executor(self, model, distributed_init_method) -> None:
    method _init_workers_sp (line 69) | def _init_workers_sp(self, model, distributed_init_method: str):
    method determine_num_available_blocks (line 97) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method initialize_cache (line 117) | def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -...
    method init_cache_engine (line 140) | def init_cache_engine(self) -> None:
    method free_cache_engine (line 143) | def free_cache_engine(self) -> None:
    method execute_model (line 146) | def execute_model(self, execute_model_req) -> List[SamplerOutput]:
    method add_lora (line 154) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 158) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 162) | def list_loras(self) -> Set[int]:
    method check_health (line 165) | def check_health(self) -> None:
    method offload_model_weights (line 171) | def offload_model_weights(self) -> None:
    method sync_model_weights (line 174) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
  function initialize_cluster (line 178) | def initialize_cluster(
  function get_open_port (line 202) | def get_open_port():
  class SPMDGPUExecutorAsync (line 209) | class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
    method execute_model_async (line 211) | async def execute_model_async(self, execute_model_req: ExecuteModelReq...
    method check_health_async (line 215) | async def check_health_async(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py
  class TokenizerGroup (line 25) | class TokenizerGroup:
    method __init__ (line 28) | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, ...
    method ping (line 35) | def ping(self) -> bool:
    method get_max_input_len (line 39) | def get_max_input_len(self, lora_request: Optional[LoRARequest] = None...
    method encode (line 43) | def encode(self,
    method encode_async (line 50) | async def encode_async(self,
    method get_lora_tokenizer (line 57) | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "...
    method pad_token_id (line 72) | def pad_token_id(self):
    method eos_token_id (line 76) | def eos_token_id(self):

FILE: verl/third_party/vllm/vllm_v_0_4_2/worker.py
  class Worker (line 42) | class Worker(Worker):
    method __init__ (line 50) | def __init__(
    method init_device (line 105) | def init_device(self) -> None:
    method determine_num_available_blocks (line 142) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method _init_cache_engine (line 199) | def _init_cache_engine(self):
    method free_cache_engine (line 203) | def free_cache_engine(self):
    method execute_model (line 209) | def execute_model(self, execute_model_req: Optional[ExecuteModelReques...
    method sync_model_weights (line 237) | def sync_model_weights(self, actor_weights: Dict, load_format: str):
    method offload_model_weights (line 246) | def offload_model_weights(self) -> None:
  function init_worker_distributed_environment (line 257) | def init_worker_distributed_environment(

FILE: verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py
  function nullable_str (line 43) | def nullable_str(val: str):
  class EngineArgs (line 50) | class EngineArgs:
    method add_cli_args (line 140) | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.Argument...
    method from_cli_args (line 257) | def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
    method create_engine_config (line 264) | def create_engine_config(

FILE: verl/third_party/vllm/vllm_v_0_5_4/config.py
  class ModelConfig (line 38) | class ModelConfig(ModelConfig):
    method __init__ (line 99) | def __init__(
  class LoadFormat (line 181) | class LoadFormat(str, enum.Enum):
  class LoadConfig (line 193) | class LoadConfig:
    method __post_init__ (line 221) | def __post_init__(self):
    method _verify_load_format (line 232) | def _verify_load_format(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py
  function gemma_dtensor_weight_loader (line 27) | def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function gptbigcode_dtensor_load_weights (line 64) | def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function starcoder2_dtensor_load_weights (line 79) | def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function llama_dtensor_weight_loader (line 110) | def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2_dtensor_weight_loader (line 154) | def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function deepseekv2_dtensor_weight_loader (line 194) | def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn...
  function gpt2_dtensor_weight_loader (line 270) | def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modul...
  function redistribute_dtensor (line 274) | def redistribute_dtensor(param_name: str, loaded_weights: DTensor, paral...
  function _process_parameter_names (line 287) | def _process_parameter_names(name):
  function load_dtensor_weights (line 323) | def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 331) | def _get_model_weight_loader(arch: str):
  function update_dtensor_weight_loader (line 339) | def update_dtensor_weight_loader():

FILE: verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py
  function update_hf_weight_loader (line 25) | def update_hf_weight_loader():
  function load_hf_weights (line 30) | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):

FILE: verl/third_party/vllm/vllm_v_0_5_4/llm.py
  class LLM (line 43) | class LLM(LLM):
    method __init__ (line 97) | def __init__(
    method init_cache_engine (line 151) | def init_cache_engine(self):
    method free_cache_engine (line 154) | def free_cache_engine(self):
    method get_tokenizer (line 157) | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokeni...
    method set_tokenizer (line 160) | def set_tokenizer(
    method _run_engine (line 166) | def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, ...
    method _post_process_outputs (line 214) | def _post_process_outputs(self, request_outputs: List[RequestOutput]) ...
    method sync_model_weights (line 235) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 238) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py
  class LLMEngine (line 46) | class LLMEngine(LLMEngine):
    method __init__ (line 77) | def __init__(
    method _init_tokenizer (line 263) | def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
    method init_cache_engine (line 270) | def init_cache_engine(self):
    method free_cache_engine (line 275) | def free_cache_engine(self):
    method _get_executor_cls (line 281) | def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[Execut...
    method from_engine_args (line 293) | def from_engine_args(
    method sync_model_weights (line 324) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 327) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py
  function parallel_weight_loader (line 27) | def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: tor...
  function default_weight_loader (line 37) | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tens...
  function gpt2_weight_loader (line 45) | def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn...
  function llama_megatron_weight_loader (line 73) | def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Mod...
  function llama_megatron_core_te_weight_loader (line 85) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 116) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 146) | def _replace_name(megatron_name, name_mapping):
  function llama_megatron_core_te_weight_loader (line 169) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 200) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 230) | def _replace_name(megatron_name, name_mapping):
  function mistral_megatron_weight_loader (line 253) | def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.M...
  function load_megatron_weights (line 290) | def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 298) | def _get_model_weight_loader(arch: str):
  function update_megatron_weight_loader (line 305) | def update_megatron_weight_loader():

FILE: verl/third_party/vllm/vllm_v_0_5_4/model_loader.py
  function get_model (line 35) | def get_model(actor_model: Union[PreTrainedModel, Dict],
  function get_model_loader (line 64) | def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  class DummyModelLoader (line 103) | class DummyModelLoader(BaseModelLoader):
    method __init__ (line 106) | def __init__(self, load_config: LoadConfig):
    method load_model (line 112) | def load_model(self, *, model_config: ModelConfig, device_config: Devi...
  class MegatronLoader (line 125) | class MegatronLoader(BaseModelLoader):
    method __init__ (line 128) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 134) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 143) | def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_...
  class HFLoader (line 172) | class HFLoader(BaseModelLoader):
    method __init__ (line 175) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 181) | def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Di...
    method load_model (line 189) | def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_...
  class DTensorLoader (line 212) | class DTensorLoader(BaseModelLoader):
    method __init__ (line 215) | def __init__(self, load_config: LoadConfig):
    method _get_weights_iterator (line 221) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 230) | def load_model(self, actor_model: Union[PreTrainedModel, Dict], model_...
  function _get_logits (line 263) | def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
  function logitsprocessor_init (line 279) | def logitsprocessor_init(self,

FILE: verl/third_party/vllm/vllm_v_0_5_4/model_runner.py
  class BatchType (line 43) | class BatchType(IntEnum):
  class ModelRunner (line 52) | class ModelRunner(ModelRunner):
    method __init__ (line 54) | def __init__(
    method load_model (line 89) | def load_model(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py
  function initialize_parallel_state (line 37) | def initialize_parallel_state(
  function ensure_model_parallel_initialized (line 68) | def ensure_model_parallel_initialized(
  function model_parallel_is_initialized (line 95) | def model_parallel_is_initialized():
  function initialize_model_parallel_for_vllm (line 101) | def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
  function initialize_model_parallel (line 191) | def initialize_model_parallel(
  function get_device_mesh (line 272) | def get_device_mesh():
  function get_tensor_model_parallel_group (line 282) | def get_tensor_model_parallel_group():
  function get_tensor_model_parallel_world_size (line 288) | def get_tensor_model_parallel_world_size():
  function get_tensor_model_parallel_rank (line 293) | def get_tensor_model_parallel_rank():
  function get_tensor_model_parallel_src_rank (line 298) | def get_tensor_model_parallel_src_rank():

FILE: verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py
  class SPMDGPUExecutor (line 34) | class SPMDGPUExecutor(ExecutorBase):
    method __init__ (line 37) | def __init__(
    method _init_executor (line 66) | def _init_executor(self, model, distributed_init_method) -> None:
    method _init_workers_sp (line 72) | def _init_workers_sp(self, model, distributed_init_method: str):
    method determine_num_available_blocks (line 107) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method initialize_cache (line 127) | def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -...
    method init_cache_engine (line 150) | def init_cache_engine(self) -> None:
    method free_cache_engine (line 153) | def free_cache_engine(self) -> None:
    method execute_model (line 156) | def execute_model(self, execute_model_req) -> List[SamplerOutput]:
    method add_lora (line 164) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 168) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 172) | def list_loras(self) -> Set[int]:
    method check_health (line 175) | def check_health(self) -> None:
    method add_prompt_adapter (line 183) | def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequ...
    method list_prompt_adapters (line 188) | def list_prompt_adapters(self) -> Set[int]:
    method pin_lora (line 191) | def pin_lora(self, lora_id: int) -> bool:
    method pin_prompt_adapter (line 195) | def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    method remove_prompt_adapter (line 200) | def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    method offload_model_weights (line 206) | def offload_model_weights(self) -> None:
    method sync_model_weights (line 209) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
  function initialize_cluster (line 213) | def initialize_cluster(
  function get_open_port (line 237) | def get_open_port():
  class SPMDGPUExecutorAsync (line 244) | class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
    method execute_model_async (line 246) | async def execute_model_async(self, execute_model_req: ExecuteModelReq...
    method check_health_async (line 250) | async def check_health_async(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py
  class TokenizerGroup (line 25) | class TokenizerGroup:
    method __init__ (line 28) | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, ...
    method ping (line 35) | def ping(self) -> bool:
    method get_max_input_len (line 39) | def get_max_input_len(self, lora_request: Optional[LoRARequest] = None...
    method encode (line 43) | def encode(self,
    method encode_async (line 50) | async def encode_async(self,
    method get_lora_tokenizer (line 57) | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "...
    method pad_token_id (line 72) | def pad_token_id(self):
    method eos_token_id (line 76) | def eos_token_id(self):

FILE: verl/third_party/vllm/vllm_v_0_5_4/worker.py
  class Worker (line 44) | class Worker(Worker):
    method __init__ (line 52) | def __init__(
    method init_device (line 134) | def init_device(self) -> None:
    method determine_num_available_blocks (line 171) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method _init_cache_engine (line 229) | def _init_cache_engine(self):
    method free_cache_engine (line 233) | def free_cache_engine(self):
    method execute_model (line 239) | def execute_model(self,
    method sync_model_weights (line 266) | def sync_model_weights(self, actor_weights: Dict, load_format: str):
    method offload_model_weights (line 275) | def offload_model_weights(self) -> None:
  function init_worker_distributed_environment (line 286) | def init_worker_distributed_environment(

FILE: verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py
  class EngineArgs (line 27) | class EngineArgs(EngineArgs):
    method __post_init__ (line 30) | def __post_init__(self):
    method create_model_config (line 33) | def create_model_config(self) -> ModelConfig:
    method create_load_config (line 62) | def create_load_config(self) -> LoadConfig:
    method create_engine_config (line 70) | def create_engine_config(self) -> EngineConfig:

FILE: verl/third_party/vllm/vllm_v_0_6_3/config.py
  class LoadFormat (line 34) | class LoadFormat(str, enum.Enum):
  class ModelConfig (line 44) | class ModelConfig(ModelConfig):
    method __init__ (line 46) | def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
  class LoadConfig (line 52) | class LoadConfig:
    method __post_init__ (line 80) | def __post_init__(self):
    method _verify_load_format (line 91) | def _verify_load_format(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py
  function gemma_dtensor_weight_loader (line 24) | def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function gptbigcode_dtensor_load_weights (line 61) | def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function starcoder2_dtensor_load_weights (line 76) | def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn....
  function llama_dtensor_weight_loader (line 107) | def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2_dtensor_weight_loader (line 151) | def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modu...
  function qwen2vl_dtensor_weight_loader (line 188) | def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Mo...
  function deepseekv2_dtensor_weight_loader (line 228) | def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn...
  function gpt2_dtensor_weight_loader (line 308) | def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Modul...
  function redistribute_dtensor (line 312) | def redistribute_dtensor(param_name: str, loaded_weights: DTensor, paral...
  function _process_parameter_names (line 326) | def _process_parameter_names(name):
  function load_dtensor_weights (line 363) | def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 371) | def _get_model_weight_loader(arch: str):
  function update_dtensor_weight_loader (line 379) | def update_dtensor_weight_loader():

FILE: verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py
  function update_hf_weight_loader (line 22) | def update_hf_weight_loader():
  function load_hf_weights (line 27) | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):

FILE: verl/third_party/vllm/vllm_v_0_6_3/llm.py
  class LLM (line 31) | class LLM(LLM):
    method __init__ (line 85) | def __init__(
    method init_cache_engine (line 145) | def init_cache_engine(self):
    method free_cache_engine (line 148) | def free_cache_engine(self):
    method get_tokenizer (line 151) | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokeni...
    method set_tokenizer (line 154) | def set_tokenizer(
    method _run_engine (line 160) | def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, ...
    method _post_process_outputs (line 174) | def _post_process_outputs(self, request_outputs: List[RequestOutput]) ...
    method sync_model_weights (line 196) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 199) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py
  class LLMEngine (line 61) | class LLMEngine(LLMEngine):
    method __init__ (line 95) | def __init__(
    method _init_tokenizer (line 337) | def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
    method init_cache_engine (line 344) | def init_cache_engine(self):
    method free_cache_engine (line 349) | def free_cache_engine(self):
    method _get_executor_cls (line 355) | def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[Execut...
    method from_engine_args (line 372) | def from_engine_args(
    method sync_model_weights (line 404) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
    method offload_model_weights (line 407) | def offload_model_weights(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py
  function parallel_weight_loader (line 26) | def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: tor...
  function default_weight_loader (line 37) | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tens...
  function gpt2_weight_loader (line 46) | def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn...
  function llama_megatron_weight_loader (line 74) | def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Mod...
  function qwen2_megatron_weight_loader (line 86) | def qwen2_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Mod...
  function llama_megatron_core_te_weight_loader (line 98) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 129) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 159) | def _replace_name(megatron_name, name_mapping):
  function llama_megatron_core_te_weight_loader (line 182) | def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model...
  function llama_megatron_core_weight_loader (line 213) | def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: n...
  function _replace_name (line 243) | def _replace_name(megatron_name, name_mapping):
  function mistral_megatron_weight_loader (line 266) | def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.M...
  function load_megatron_weights (line 304) | def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
  function _get_model_weight_loader (line 312) | def _get_model_weight_loader(arch: str):
  function update_megatron_weight_loader (line 319) | def update_megatron_weight_loader():

FILE: verl/third_party/vllm/vllm_v_0_6_3/model_loader.py
  function get_model (line 33) | def get_model(
  function get_model_loader (line 65) | def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  class DummyModelLoader (line 104) | class DummyModelLoader(BaseModelLoader):
    method __init__ (line 107) | def __init__(self, load_config: LoadConfig):
    method download_model (line 113) | def download_model(self, model_config: ModelConfig) -> None:
    method load_model (line 116) | def load_model(
  class MegatronLoader (line 135) | class MegatronLoader(BaseModelLoader):
    method __init__ (line 138) | def __init__(self, load_config: LoadConfig):
    method download_model (line 144) | def download_model(self, model_config: ModelConfig) -> None:
    method _get_weights_iterator (line 147) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 156) | def load_model(
  class HFLoader (line 190) | class HFLoader(BaseModelLoader):
    method __init__ (line 193) | def __init__(self, load_config: LoadConfig):
    method download_model (line 199) | def download_model(self, model_config: ModelConfig) -> None:
    method _get_weights_iterator (line 202) | def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Di...
    method load_model (line 210) | def load_model(
  class DTensorLoader (line 238) | class DTensorLoader(BaseModelLoader):
    method __init__ (line 241) | def __init__(self, load_config: LoadConfig):
    method download_model (line 247) | def download_model(self, model_config: ModelConfig) -> None:
    method _get_weights_iterator (line 250) | def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]):
    method load_model (line 259) | def load_model(
  function _get_logits (line 297) | def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
  function logitsprocessor_init (line 313) | def logitsprocessor_init(

FILE: verl/third_party/vllm/vllm_v_0_6_3/model_runner.py
  class BatchType (line 51) | class BatchType(IntEnum):
  class ModelRunner (line 60) | class ModelRunner(ModelRunner):
    method __init__ (line 62) | def __init__(
    method load_model (line 101) | def load_model(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py
  function initialize_parallel_state (line 38) | def initialize_parallel_state(
  function ensure_model_parallel_initialized (line 71) | def ensure_model_parallel_initialized(
  function model_parallel_is_initialized (line 98) | def model_parallel_is_initialized():
  function initialize_model_parallel_for_vllm (line 104) | def initialize_model_parallel_for_vllm(
  function initialize_model_parallel (line 199) | def initialize_model_parallel(
  function get_device_mesh (line 281) | def get_device_mesh():
  function get_tensor_model_parallel_group (line 291) | def get_tensor_model_parallel_group():
  function get_tensor_model_parallel_world_size (line 297) | def get_tensor_model_parallel_world_size():
  function get_tensor_model_parallel_rank (line 302) | def get_tensor_model_parallel_rank():
  function get_tensor_model_parallel_src_rank (line 307) | def get_tensor_model_parallel_src_rank():

FILE: verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py
  class SPMDGPUExecutor (line 42) | class SPMDGPUExecutor(ExecutorBase):
    method __init__ (line 45) | def __init__(
    method _init_executor (line 74) | def _init_executor(self, model, distributed_init_method) -> None:
    method _init_workers_sp (line 80) | def _init_workers_sp(self, model, distributed_init_method: str):
    method determine_num_available_blocks (line 114) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method initialize_cache (line 134) | def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -...
    method init_cache_engine (line 156) | def init_cache_engine(self) -> None:
    method free_cache_engine (line 159) | def free_cache_engine(self) -> None:
    method execute_model (line 162) | def execute_model(self, execute_model_req) -> List[SamplerOutput]:
    method add_lora (line 170) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 174) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 178) | def list_loras(self) -> Set[int]:
    method check_health (line 181) | def check_health(self) -> None:
    method add_prompt_adapter (line 189) | def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequ...
    method list_prompt_adapters (line 193) | def list_prompt_adapters(self) -> Set[int]:
    method pin_lora (line 196) | def pin_lora(self, lora_id: int) -> bool:
    method pin_prompt_adapter (line 200) | def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    method remove_prompt_adapter (line 204) | def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    method offload_model_weights (line 209) | def offload_model_weights(self) -> None:
    method sync_model_weights (line 212) | def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], l...
  function initialize_cluster (line 216) | def initialize_cluster(
  function get_open_port (line 240) | def get_open_port():
  class SPMDGPUExecutorAsync (line 247) | class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
    method execute_model_async (line 249) | async def execute_model_async(self, execute_model_req: ExecuteModelReq...
    method check_health_async (line 253) | async def check_health_async(self) -> None:

FILE: verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py
  class TokenizerGroup (line 23) | class TokenizerGroup(TokenizerGroup):
    method __init__ (line 26) | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, ...
    method pad_token_id (line 35) | def pad_token_id(self):
    method eos_token_id (line 39) | def eos_token_id(self):

FILE: verl/third_party/vllm/vllm_v_0_6_3/worker.py
  class Worker (line 53) | class Worker(Worker):
    method __init__ (line 61) | def __init__(
    method init_device (line 140) | def init_device(self) -> None:
    method determine_num_available_blocks (line 177) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method _init_cache_engine (line 235) | def _init_cache_engine(self):
    method free_cache_engine (line 239) | def free_cache_engine(self):
    method execute_model (line 245) | def execute_model(self,
    method sync_model_weights (line 274) | def sync_model_weights(self, actor_weights: Dict, load_format: str):
    method offload_model_weights (line 283) | def offload_model_weights(self) -> None:
  function init_worker_distributed_environment (line 294) | def init_worker_distributed_environment(

FILE: verl/trainer/fsdp_sft_trainer.py
  function extract_step (line 59) | def extract_step(path):
  function convert_to_regular_types (line 66) | def convert_to_regular_types(obj):
  class FSDPSFTTrainer (line 78) | class FSDPSFTTrainer(object):
    method __init__ (line 80) | def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mes...
    method _normalize_config_bsz (line 110) | def _normalize_config_bsz(self):
    method _build_dataloader (line 121) | def _build_dataloader(self):
    method _build_model_optimizer (line 181) | def _build_model_optimizer(self):
    method _compute_loss_and_backward (line 287) | def _compute_loss_and_backward(self, batch, do_backward=True):
    method training_step (line 384) | def training_step(self, batch: TensorDict):
    method validation_step (line 419) | def validation_step(self, batch: TensorDict):
    method save_checkpoint (line 426) | def save_checkpoint(self, step):
    method fit (line 444) | def fit(self):
  function main (line 520) | def main(config):

FILE: verl/trainer/main_eval.py
  function select_reward_fn (line 27) | def select_reward_fn(data_source):
  function main (line 35) | def main(config):

FILE: verl/trainer/main_generation.py
  function main (line 40) | def main(config):

FILE: verl/trainer/main_ppo.py
  function main (line 24) | def main(config):
  function run_ppo (line 28) | def run_ppo(config, compute_score=None):
  function main_task (line 37) | def main_task(config, compute_score=None):

FILE: verl/trainer/ppo/core_algos.py
  class AdaptiveKLController (line 28) | class AdaptiveKLController:
    method __init__ (line 34) | def __init__(self, init_kl_coef, target_kl, horizon):
    method update (line 39) | def update(self, current_kl, n_steps):
  class FixedKLController (line 46) | class FixedKLController:
    method __init__ (line 49) | def __init__(self, kl_coef):
    method update (line 52) | def update(self, current_kl, n_steps):
  function get_kl_controller (line 56) | def get_kl_controller(config):
  function compute_gae_advantage_return (line 70) | def compute_gae_advantage_return(token_level_rewards: torch.Tensor, valu...
  function compute_grpo_outcome_advantage (line 111) | def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
  function compute_rloo_outcome_advantage (line 157) | def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
  function compute_reinforce_plus_plus_outcome_advantage (line 202) | def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: t...
  function compute_remax_outcome_advantage (line 236) | def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, r...
  function compute_rewards (line 267) | def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_r...
  function compute_policy_loss (line 272) | def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cl...
  function compute_entropy_loss (line 306) | def compute_entropy_loss(logits, eos_mask):
  function compute_value_loss (line 325) | def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
  function kl_penalty (line 351) | def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTenso...
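
The estimators in core_algos.py are standard PPO machinery. As a reference point, here is a minimal sketch of the GAE recursion and the clipped surrogate that the compute_gae_advantage_return and compute_policy_loss signatures name; the function names and the eos_mask convention (1 on valid response tokens, 0 on padding) are assumptions for illustration, not a quote of the repo's code.

import torch

def gae_advantage_return(rewards, values, eos_mask, gamma=1.0, lam=1.0):
    # rewards, values, eos_mask: (batch, seq_len) token-level tensors
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    seq_len = rewards.shape[-1]
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t + 1 < seq_len else 0.0
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values
    return advantages * eos_mask, returns

def clipped_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange=0.2):
    # PPO clipped surrogate: pessimistic max over unclipped / clipped terms
    ratio = torch.exp(log_prob - old_log_prob)
    loss_unclipped = -advantages * ratio
    loss_clipped = -advantages * torch.clamp(ratio, 1 - cliprange, 1 + cliprange)
    loss = torch.max(loss_unclipped, loss_clipped)
    # masked mean over valid response tokens only
    return (loss * eos_mask).sum() / eos_mask.sum()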

FILE: verl/trainer/ppo/ray_trainer.py
  class Role (line 47) | class Role(Enum):
  class AdvantageEstimator (line 60) | class AdvantageEstimator(str, Enum):
  class ResourcePoolManager (line 72) | class ResourcePoolManager:
    method create_resource_pool (line 81) | def create_resource_pool(self):
    method get_resource_pool (line 92) | def get_resource_pool(self, role: Role) -> RayResourcePool:
  function apply_kl_penalty (line 101) | def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLCont...
  function compute_advantage (line 133) | def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0...
  function reduce_metrics (line 205) | def reduce_metrics(metrics: dict):
  function _compute_response_info (line 211) | def _compute_response_info(batch):
  function compute_data_metrics (line 227) | def compute_data_metrics(batch, use_critic=True):
  function compute_timing_metrics (line 315) | def compute_timing_metrics(batch, timing_raw):
  function _timer (line 340) | def _timer(name: str, timing_raw: Dict[str, float]):
  class RayPPOTrainer (line 346) | class RayPPOTrainer(object):
    method __init__ (line 353) | def __init__(self,
    method _validate_config (line 410) | def _validate_config(self):
    method _create_dataloader (line 494) | def _create_dataloader(self):
    method _maybe_log_val_generations_to_wandb (line 562) | def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores):
    method _validate (line 612) | def _validate(self):
    method _save_samples (line 711) | def _save_samples(self, batch: DataProto, split: str):
    method init_workers (line 733) | def init_workers(self):
    method _save_checkpoint (line 800) | def _save_checkpoint(self):
    method _load_checkpoint (line 833) | def _load_checkpoint(self):
    method _balance_batch (line 886) | def _balance_batch(self, batch: DataProto, metrics, logging_prefix='gl...
    method fit (line 903) | def fit(self):

FILE: verl/utils/checkpoint/checkpoint_manager.py
  class BaseCheckpointManager (line 27) | class BaseCheckpointManager:
    method __init__ (line 42) | def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
    method load_checkpoint (line 57) | def load_checkpoint(self, *args, **kwargs):
    method save_checkpoint (line 60) | def save_checkpoint(self, *args, **kwargs):
    method remove_previous_save_local_path (line 63) | def remove_previous_save_local_path(self):
    method local_mkdir (line 76) | def local_mkdir(path):
    method get_rng_state (line 97) | def get_rng_state():
    method load_rng_state (line 107) | def load_rng_state(rng_state):
  function find_latest_ckpt_path (line 114) | def find_latest_ckpt_path(path, directory_format="global_step_{}"):
  function get_checkpoint_tracker_filename (line 134) | def get_checkpoint_tracker_filename(root_path: str):
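
A plausible sketch of what find_latest_ckpt_path does, assuming checkpoint directories follow the directory_format="global_step_{}" naming visible in its signature; the implementation details below are an assumption.

import os
import re

def find_latest_ckpt_path(path, directory_format="global_step_{}"):
    # Scan `path` for directories matching the format and return the one
    # with the highest step number, or None if nothing matches.
    if not os.path.isdir(path):
        return None
    pattern = re.compile(directory_format.format(r"(\d+)"))
    steps = []
    for name in os.listdir(path):
        m = pattern.fullmatch(name)
        if m and os.path.isdir(os.path.join(path, name)):
            steps.append((int(m.group(1)), name))
    if not steps:
        return None
    _, latest = max(steps)
    return os.path.join(path, latest)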

FILE: verl/utils/checkpoint/fsdp_checkpoint_manager.py
  class FSDPCheckpointManager (line 32) | class FSDPCheckpointManager(BaseCheckpointManager):
    method __init__ (line 47) | def __init__(self,
    method load_checkpoint (line 61) | def load_checkpoint(self, path=None, del_local_after_load=False, *args...
    method save_checkpoint (line 106) | def save_checkpoint(self, local_path: str, global_step: int, remove_pr...

FILE: verl/utils/config.py
  function update_dict_with_config (line 20) | def update_dict_with_config(dictionary: Dict, config: DictConfig):

FILE: verl/utils/dataset/rl_dataset.py
  function collate_fn (line 36) | def collate_fn(data_list: list[dict]) -> dict:
  function process_image (line 56) | def process_image(image: dict, max_pixels: int = 2048 * 2048, min_pixels...
  class RLHFDataset (line 80) | class RLHFDataset(Dataset):
    method __init__ (line 85) | def __init__(self,
    method _download (line 123) | def _download(self, use_origin_parquet=False):
    method _read_files_and_tokenize (line 129) | def _read_files_and_tokenize(self):
    method resume_dataset_state (line 153) | def resume_dataset_state(self):
    method __len__ (line 162) | def __len__(self):
    method __getitem__ (line 165) | def __getitem__(self, item):
    method __getstate__ (line 238) | def __getstate__(self):
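
The collate_fn signature above suggests the usual RLHF batching pattern: tensor fields are stacked into a batch dimension, while string/metadata fields are collected into object arrays. A minimal sketch under that assumption:

import numpy as np
import torch

def collate_fn(data_list: list[dict]) -> dict:
    tensors, non_tensors = {}, {}
    for item in data_list:
        for key, val in item.items():
            target = tensors if isinstance(val, torch.Tensor) else non_tensors
            target.setdefault(key, []).append(val)
    # stack tensors along a new batch dim; keep everything else as objects
    batch = {key: torch.stack(vals, dim=0) for key, vals in tensors.items()}
    batch.update({key: np.array(vals, dtype=object) for key, vals in non_tensors.items()})
    return batch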

FILE: verl/utils/dataset/rm_dataset.py
  function download_files_distributed (line 27) | def download_files_distributed(download_fn):
  class RMDataset (line 40) | class RMDataset(Dataset):
    method __init__ (line 42) | def __init__(self,
    method _download (line 70) | def _download(self):
    method _read_files_and_tokenize (line 85) | def _read_files_and_tokenize(self):
    method __len__ (line 96) | def __len__(self):
    method _pad_to_length (line 99) | def _pad_to_length(self, input_ids, attention_mask):
    method __getitem__ (line 114) | def __getitem__(self, item):

FILE: verl/utils/dataset/sft_dataset.py
  class SFTDataset (line 34) | class SFTDataset(Dataset):
    method __init__ (line 39) | def __init__(self,
    method _download (line 69) | def _download(self):
    method _read_files_and_tokenize (line 73) | def _read_files_and_tokenize(self):
    method __len__ (line 107) | def __len__(self):
    method __getitem__ (line 110) | def __getitem__(self, item):

FILE: verl/utils/debug/performance.py
  function log_gpu_memory_usage (line 20) | def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level...

FILE: verl/utils/debug/trajectory_tracker.py
  function save_to_hdfs (line 33) | def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
  class TrajectoryTracker (line 50) | class TrajectoryTracker():
    method __init__ (line 52) | def __init__(self, hdfs_dir, verbose) -> None:
    method dump (line 59) | def dump(self, data: io.BytesIO, name):
    method wait_for_hdfs (line 63) | def wait_for_hdfs(self):
  function dump_data (line 69) | def dump_data(data, name):
  function get_trajectory_tracker (line 79) | def get_trajectory_tracker():
  function process (line 94) | def process(iter):

FILE: verl/utils/distributed.py
  function initialize_global_process_group (line 18) | def initialize_global_process_group(timeout_second=36000):

FILE: verl/utils/flops_counter.py
  function get_device_flops (line 21) | def get_device_flops(unit="T"):
  class FlopsCounter (line 51) | class FlopsCounter:
    method __init__ (line 61) | def __init__(self, config: PretrainedConfig):
    method _estimate_unknown_flops (line 74) | def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
    method _estimate_qwen2_flops (line 77) | def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
    method estimate_flops (line 111) | def estimate_flops(self, batch_seqlens, delta_time):
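
FlopsCounter feeds an MFU-style utilization number. A common back-of-envelope estimate in this spirit is 6 * params * tokens for the dense matmuls plus a quadratic attention term; this is the usual approximation, not necessarily the repo's exact formula.

def estimate_train_flops(num_params, batch_seqlens, num_layers, hidden_size):
    # Dense part: ~6 FLOPs per parameter per token (forward + backward).
    tokens = sum(batch_seqlens)
    dense = 6 * num_params * tokens
    # Attention part: quadratic in per-sequence length.
    attn = 12 * num_layers * hidden_size * sum(s * s for s in batch_seqlens)
    return dense + attn

# Dividing by delta_time and the device peak (e.g. ~312 bf16 TFLOPs on an
# A100) yields the achieved-vs-promised utilization estimate_flops reports.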

FILE: verl/utils/fs.py
  function is_non_local (line 32) | def is_non_local(path):
  function md5_encode (line 36) | def md5_encode(path: str) -> str:
  function get_local_temp_path (line 40) | def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:
  function copy_to_local (line 58) | def copy_to_local(src: str, cache_dir=None, filelock='.file.lock', verbo...
  function copy_local_path_from_hdfs (line 72) | def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file....

FILE: verl/utils/fsdp_utils.py
  function init_fn (line 31) | def init_fn(x: torch.nn.Module):
  function get_init_weight_context_manager (line 38) | def get_init_weight_context_manager(use_meta_tensor=True):
  function get_fsdp_wrap_policy (line 50) | def get_fsdp_wrap_policy(module, config=None, is_lora=False):
  function offload_fsdp_model_to_cpu (line 111) | def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
  function load_fsdp_model_to_gpu (line 132) | def load_fsdp_model_to_gpu(model: FSDP):
  function offload_fsdp_optimizer (line 148) | def offload_fsdp_optimizer(optimizer):
  function load_fsdp_optimizer (line 160) | def load_fsdp_optimizer(optimizer, device_id):
  function meta_device_init (line 172) | def meta_device_init():
  function parallel_load_safetensors (line 203) | def parallel_load_safetensors(filepath):
  function parallel_init_module_fn (line 259) | def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[...

FILE: verl/utils/hdfs_io.py
  function exists (line 27) | def exists(path: str, **kwargs) -> bool:
  function _exists (line 43) | def _exists(file_path: str):
  function makedirs (line 50) | def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:
  function _mkdir (line 75) | def _mkdir(file_path: str) -> bool:
  function copy (line 84) | def copy(src: str, dst: str, **kwargs) -> bool:
  function _copy (line 113) | def _copy(from_path: str, to_path: str, timeout: int = None) -> bool:
  function _run_cmd (line 135) | def _run_cmd(cmd: str, timeout=None):
  function _hdfs_cmd (line 139) | def _hdfs_cmd(cmd: str) -> str:
  function _is_non_local (line 143) | def _is_non_local(path: str):

FILE: verl/utils/import_utils.py
  function is_megatron_core_available (line 24) | def is_megatron_core_available():
  function is_vllm_available (line 33) | def is_vllm_available():
  function import_external_libs (line 41) | def import_external_libs(external_libs=None):

FILE: verl/utils/logger/aggregate_logger.py
  function concat_dict_to_str (line 21) | def concat_dict_to_str(dict: Dict, step):
  class LocalLogger (line 30) | class LocalLogger:
    method __init__ (line 32) | def __init__(self, remote_logger=None, enable_wandb=False, print_to_co...
    method flush (line 37) | def flush(self):
    method log (line 40) | def log(self, data, step):

FILE: verl/utils/logging_utils.py
  function set_basic_config (line 18) | def set_basic_config(level):

FILE: verl/utils/megatron/memory.py
  class MemoryBuffer (line 18) | class MemoryBuffer:
    method __init__ (line 20) | def __init__(self, numel, numel_padded, dtype):
    method zero (line 29) | def zero(self):
    method get (line 33) | def get(self, shape, start_index):

FILE: verl/utils/megatron/optimizer.py
  function get_megatron_optimizer (line 27) | def get_megatron_optimizer(

FILE: verl/utils/megatron/pipeline_parallel.py
  function compute_transformers_input_shapes (line 22) | def compute_transformers_input_shapes(batches, meta_info):
  function make_batch_generator (line 43) | def make_batch_generator(batches, vpp_size):

FILE: verl/utils/megatron/sequence_parallel.py
  function mark_parameter_as_sequence_parallel (line 21) | def mark_parameter_as_sequence_parallel(parameter):
  function is_sequence_parallel_param (line 25) | def is_sequence_parallel_param(param):
  function pad_to_sequence_parallel (line 29) | def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):

FILE: verl/utils/megatron/tensor_parallel.py
  function update_kwargs_with_config (line 27) | def update_kwargs_with_config(dictionary: Dict, config: ModelParallelCon...
  function get_default_kwargs_for_model_parallel_config (line 32) | def get_default_kwargs_for_model_parallel_config():
  function get_default_model_parallel_config (line 43) | def get_default_model_parallel_config():
  function get_common_default_kwargs_for_parallel_linear (line 47) | def get_common_default_kwargs_for_parallel_linear():
  function get_default_kwargs_for_column_parallel_linear (line 58) | def get_default_kwargs_for_column_parallel_linear():
  function get_default_kwargs_for_row_parallel_linear (line 72) | def get_default_kwargs_for_row_parallel_linear():
  function get_default_kwargs_for_parallel_embedding (line 77) | def get_default_kwargs_for_parallel_embedding():
  function is_tensor_parallel_param (line 86) | def is_tensor_parallel_param(param):
  function get_tensor_parallel_partition_dim (line 90) | def get_tensor_parallel_partition_dim(param):
  function get_tensor_parallel_partition_stride (line 95) | def get_tensor_parallel_partition_stride(param):
  class _VocabParallelEntropy (line 100) | class _VocabParallelEntropy(torch.autograd.Function):
    method forward (line 103) | def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    method backward (line 118) | def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
  function vocab_parallel_entropy (line 124) | def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch...
  function vocab_parallel_log_probs_from_logits (line 136) | def vocab_parallel_log_probs_from_logits(logits, labels):
  function vocab_parallel_log_probs_from_logits_response_rmpad (line 141) | def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, atten...
  function vocab_parallel_compute_entropy_loss (line 168) | def vocab_parallel_compute_entropy_loss(logits, eos_mask):

FILE: verl/utils/megatron_utils.py
  function get_model_config (line 38) | def get_model_config(model):
  function get_model (line 42) | def get_model(model_provider_func, model_type=ModelType.encoder_or_decod...
  function unwrap_model (line 136) | def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
  function convert_config (line 154) | def convert_config(hf_config: PretrainedConfig, megatron_config) -> Tran...
  function init_megatron_optim_config (line 195) | def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
  function init_model_parallel_config (line 208) | def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
  function offload_megatron_param_and_grad (line 222) | def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_...
  function load_megatron_param_and_grad (line 237) | def load_megatron_param_and_grad(module_list: nn.ModuleList, device_id, ...

FILE: verl/utils/memory_buffer.py
  class MemoryBuffer (line 24) | class MemoryBuffer:
    method __init__ (line 30) | def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, ...
    method zero (line 39) | def zero(self):
    method get (line 43) | def get(self, shape, start_index):
  function calc_padded_numel (line 54) | def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
  function get_weight_buffer_meta_from_module (line 61) | def get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, D...
  function build_memory_buffer (line 71) | def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[tor...
  function build_memory_reference_from_module (line 100) | def build_memory_reference_from_module(module: torch.nn.Module,
  function build_memory_reference (line 116) | def build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_b...
  class MemoryBufferModuleWrapper (line 143) | class MemoryBufferModuleWrapper:
    method __init__ (line 149) | def __init__(self, module: nn.Module):
    method get_memory_buffers (line 156) | def get_memory_buffers(self):
    method get_weight_buffer_meta (line 159) | def get_weight_buffer_meta(self):
  class MegatronMemoryBufferForRollout (line 163) | class MegatronMemoryBufferForRollout(object):
    method __init__ (line 178) | def __init__(self, transform_memory_param_fn):
    method initialize_weight_buffer (line 184) | def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[st...
    method build_memory_reference (line 202) | def build_memory_reference(self):
    method named_parameters (line 208) | def named_parameters(self):
    method weight_buffers (line 212) | def weight_buffers(self):
    method memory_buffers (line 216) | def memory_buffers(self):
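
The MemoryBuffer interface above implies the flat-buffer pattern: one contiguous allocation per dtype, with each weight handed out as a reshaped view at a fixed offset. A sketch under that assumption (FlatBuffer is a hypothetical name; the real buffer likely lives on GPU):

import torch

class FlatBuffer:
    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype):
        self.numel = numel
        # padded allocation so offsets can be aligned
        self.data = torch.zeros(numel_padded, dtype=dtype)

    def zero(self):
        self.data.zero_()

    def get(self, shape: torch.Size, start_index: int):
        # Return a view into the flat buffer; no copy is made, so writes
        # through the view update the shared storage.
        end = start_index + shape.numel()
        assert end <= self.numel, "requested region falls outside the buffer"
        return self.data[start_index:end].view(*shape)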

FILE: verl/utils/model.py
  class LambdaLayer (line 28) | class LambdaLayer(nn.Module):
    method __init__ (line 30) | def __init__(self, fn):
    method forward (line 34) | def forward(self, *args, **kwargs):
  function squeeze (line 38) | def squeeze(x):
  function update_model_config (line 42) | def update_model_config(module_config, override_config_kwargs):
  function get_huggingface_actor_config (line 47) | def get_huggingface_actor_config(model_name: str, override_config_kwargs...
  function get_generation_config (line 58) | def get_generation_config(
  function create_huggingface_actor (line 75) | def create_huggingface_actor(model_name: str, override_config_kwargs=Non...
  function create_huggingface_critic (line 98) | def create_huggingface_critic(model_name: str, override_config_kwargs=No...
  function get_model_size (line 119) | def get_model_size(model: nn.Module, scale='auto'):
  function print_model_size (line 146) | def print_model_size(model: nn.Module, name: str = None):
  function create_random_mask (line 153) | def create_random_mask(input_ids: torch.Tensor,
  function compute_position_id_with_mask (line 194) | def compute_position_id_with_mask(mask):
  function normalize_pp_vpp_params (line 198) | def normalize_pp_vpp_params(params, num_hidden_layers, layer_name='layer...
  function get_parallel_model_from_config (line 251) | def get_parallel_model_from_config(config,
  function _get_parallel_model_architecture_from_config (line 269) | def _get_parallel_model_architecture_from_config(config: PretrainedConfi...
  function load_megatron_model_weights (line 279) | def load_megatron_model_weights(config,
  function pad_packed_inputs (line 325) | def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen...
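
compute_position_id_with_mask likely follows the common cumsum trick (stated here as an assumption): positions count valid tokens, so left padding never shifts the position ids of the real prompt.

import torch

def position_ids_from_mask(mask: torch.Tensor) -> torch.Tensor:
    # mask: (batch, seq_len) with 1 on real tokens, 0 on padding
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0)

# e.g. mask [0, 0, 1, 1, 1] -> position ids [0, 0, 0, 1, 2]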

FILE: verl/utils/py_functional.py
  function union_two_dict (line 22) | def union_two_dict(dict1: Dict, dict2: Dict):
  function append_to_dict (line 41) | def append_to_dict(data: Dict, new_data: Dict):
  class NestedNamespace (line 48) | class NestedNamespace(SimpleNamespace):
    method __init__ (line 50) | def __init__(self, dictionary, **kwargs):
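
Minimal sketches of the two dict helpers above, assuming the usual metric-accumulation semantics (union rejects conflicting values; append collects per-step metrics into lists):

from typing import Dict

def union_two_dict(dict1: Dict, dict2: Dict) -> Dict:
    for key, val in dict2.items():
        if key in dict1:
            assert dict1[key] == val, f"conflicting values for key {key}"
        dict1[key] = val
    return dict1

def append_to_dict(data: Dict, new_data: Dict) -> None:
    # each key accumulates a list of values across calls
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)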

FILE: verl/utils/ray_utils.py
  function parallel_put (line 23) | def parallel_put(data_list, max_workers=None):
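
A plausible reading of parallel_put, sketched under the assumption that it overlaps many ray.put serializations with a thread pool:

from concurrent.futures import ThreadPoolExecutor
import ray

def parallel_put_sketch(data_list, max_workers=None):
    if not data_list:
        return []
    workers = max_workers or min(len(data_list), 16)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # ray.put releases the GIL during serialization, so threads help
        return list(pool.map(ray.put, data_list))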

FILE: verl/utils/rendezvous/ray_backend.py
  class NCCLIDStore (line 25) | class NCCLIDStore:
    method __init__ (line 27) | def __init__(self, nccl_id):
    method get (line 30) | def get(self):
  function get_nccl_id_store_by_name (line 34) | def get_nccl_id_store_by_name(name):
  function create_nccl_communicator_in_ray (line 47) | def create_nccl_communicator_in_ray(rank: int,

FILE: verl/utils/reward_score/__init__.py
  function _default_compute_score (line 17) | def _default_compute_score(data_source, solution_str, ground_truth, extr...

FILE: verl/utils/reward_score/eval.py
  function extract_pattern (line 11) | def extract_pattern(pred: str, pattern: str):
  function extract_split (line 26) | def extract_split(pred: str, split: str):
  function expansion (line 34) | def expansion(answer_list: str):
  function extract (line 53) | def extract(pred: str):
  function normalize_final_answer (line 146) | def normalize_final_answer(final_answer: str) -> str:
  function choice_answer_clean (line 211) | def choice_answer_clean(pred: str):
  function parse_digits (line 225) | def parse_digits(num):
  function is_digit (line 241) | def is_digit(num):
  function str_to_pmatrix (line 246) | def str_to_pmatrix(input_str):
  function math_equal (line 259) | def math_equal(
  function math_equal_process (line 458) | def math_equal_process(param):
  function numeric_equal (line 462) | def numeric_equal(prediction: float, reference: float):
  function symbolic_equal (line 467) | def symbolic_equal(a, b):
  function symbolic_equal_process (line 522) | def symbolic_equal_process(a, b, output_queue):
  function call_with_timeout (line 527) | def call_with_timeout(func, *args, timeout=1, **kwargs):
  function process_answer_list (line 541) | def process_answer_list(answer_list):
  function is_equal (line 559) | def is_equal(pred, gt):
  function exact_match_eval (line 565) | def exact_match_eval(pred, gt):

FILE: verl/utils/reward_score/geo3k.py
  function format_reward (line 19) | def format_reward(predict_str: str) -> float:
  function acc_reward (line 25) | def acc_reward(predict_str: str, ground_truth: str) -> float:
  function compute_score (line 30) | def compute_score(predict_str: str, ground_truth: str) -> float:

FILE: verl/utils/reward_score/gsm8k.py
  function extract_solution (line 18) | def extract_solution(solution_str, method='strict'):
  function compute_score (line 44) | def compute_score(solution_str, ground_truth, method='strict', format_sc...
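
The method='strict' default suggests the usual GSM8K extraction: strict mode only accepts the canonical "#### <number>" suffix, while a flexible mode falls back to the last number in the completion. A sketch under that assumption, not a verbatim copy:

import re

def extract_solution(solution_str: str, method: str = "strict"):
    if method == "strict":
        m = re.search(r"####\s*(-?[\d,]+(?:\.\d+)?)", solution_str)
        return m.group(1).replace(",", "") if m else None
    # flexible: take the last number that appears anywhere
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", solution_str)
    return numbers[-1].replace(",", "") if numbers else None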

FILE: verl/utils/reward_score/math.py
  function compute_score (line 17) | def compute_score(solution_str, ground_truth) -> float:
  function is_equiv (line 32) | def is_equiv(str1, str2, verbose=False):
  function remove_boxed (line 49) | def remove_boxed(s):
  function last_boxed_only_string (line 63) | def last_boxed_only_string(string):
  function fix_fracs (line 93) | def fix_fracs(string):
  function fix_a_slash_b (line 125) | def fix_a_slash_b(string):
  function remove_right_units (line 140) | def remove_right_units(string):
  function fix_sqrt (line 150) | def fix_sqrt(string):
  function strip_string (line 165) | def strip_string(string):
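
last_boxed_only_string / remove_boxed implement the standard MATH-style trick of pulling the answer out of the last \boxed{...} span. The brace-matching core, as a sketch (the real code returns the wrapper and strips it in a second step):

def last_boxed_only(string: str):
    idx = string.rfind("\\boxed{")
    if idx < 0:
        return None
    depth, i = 0, idx + len("\\boxed")
    start = i
    while i < len(string):
        if string[i] == "{":
            depth += 1
        elif string[i] == "}":
            depth -= 1
            if depth == 0:
                # contents between the outermost braces
                return string[start + 1:i]
        i += 1
    return None  # unbalanced braces

# last_boxed_only("so the answer is \\boxed{\\frac{1}{2}}") -> "\\frac{1}{2}"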

FILE: verl/utils/reward_score/math_verifier.py
  class TimeoutException (line 21) | class TimeoutException(Exception):
  function timeout (line 25) | def timeout(seconds):
  function check_mixed_languages (line 42) | def check_mixed_languages(text):
  function undesired_format (line 47) | def undesired_format(text):
  function check_garbled_characters (line 52) | def check_garbled_characters(text):
  function has_repeated_patterns (line 59) | def has_repeated_patterns(text):
  function correctness_score_default (line 62) | def correctness_score_default(response, gt):
  function correctness_score_v2 (line 69) | def correctness_score_v2(response, gt):
  function compute_score (line 75) | def compute_score(solution_str, ground_truth, reward_type) -> float:
  function is_equiv (line 104) | def is_equiv(str1, str2, verbose=False):
  function remove_boxed (line 130) | def remove_boxed(s):
  function last_boxed_only_string (line 144) | def last_boxed_only_string(string):
  function fix_fracs (line 174) | def fix_fracs(string):
  function fix_a_slash_b (line 206) | def fix_a_slash_b(string):
  function remove_right_units (line 221) | def remove_right_units(string):
  function fix_sqrt (line 231) | def fix_sqrt(string):
  function strip_string (line 246) | def strip_string(string):
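
The TimeoutException/timeout pair above is most likely a SIGALRM-based guard that aborts slow symbolic checks (Unix-only). A sketch under that assumption, with hypothetical names:

import signal
from contextlib import contextmanager

class VerifierTimeout(Exception):
    pass

@contextmanager
def time_limit(seconds: int):
    def handler(signum, frame):
        raise VerifierTimeout(f"verification timed out after {seconds}s")
    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel the pending alarm
        signal.signal(signal.SIGALRM, old_handler)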

FILE: verl/utils/reward_score/prime_code/__init__.py
  function compute_score (line 21) | def compute_score(completion, test_cases, continuous=False):

FILE: verl/utils/reward_score/prime_code/testing_util.py
  function truncatefn (line 42) | def truncatefn(s, length=300):
  class CODE_TYPE (line 50) | class CODE_TYPE(Enum):
  class TimeoutException (line 56) | class TimeoutException(Exception):
  function timeout_handler (line 60) | def timeout_handler(signum, frame):
  class Capturing (line 74) | class Capturing(list):
    method __enter__ (line 76) | def __enter__(self):
    method __exit__ (line 83) | def __exit__(self, *args):
  function only_int_check (line 89) | def only_int_check(val):
  function string_int_check (line 93) | def string_int_check(val):
  function combined_int_check (line 97) | def combined_int_check(val):
  function clean_traceback (line 101) | def clean_traceback(error_traceback):
  function run_test (line 108) | def run_test(in_outs, test=None, debug=False, timeout=15):
  function custom_compare_ (line 595) | def custom_compare_(output, ground_truth):
  function stripped_string_compare (line 611) | def stripped_string_compare(s1, s2):
  function call_method (line 617) | def call_method(method, inputs):
  function reliability_guard (line 644) | def reliability_guard(maximum_memory_bytes=None):

FILE: verl/utils/reward_score/prime_code/utils.py
  function _temp_run (line 25) | def _temp_run(sample, generation, debug, result, metadata_list, timeout):
  function check_correctness (line 40) | def check_correctness(in_outs: Optional[dict], generation, timeout=10, d...

FILE: verl/utils/reward_score/prime_math/__init__.py
  function _sympy_parse (line 38) | def _sympy_parse(expr: str):
  function _parse_latex (line 47) | def _parse_latex(expr: str) -> str:
  function _is_float (line 65) | def _is_float(num: str) -> bool:
  function _is_int (line 73) | def _is_int(x: float) -> bool:
  function _is_frac (line 80) | def _is_frac(expr: str) -> bool:
  function _str_is_int (line 84) | def _str_is_int(x: str) -> bool:
  function _str_to_int (line 93) | def _str_to_int(x: str) -> bool:
  function _inject_implicit_mixed_number (line 99) | def _inject_implicit_mixed_number(step: str):
  function _strip_properly_formatted_commas (line 109) | def _strip_properly_formatted_commas(expr: str):
  function _normalize (line 120) | def _normalize(expr: str) -> str:
  function count_unknown_letters_in_expr (line 189) | def count_unknown_letters_in_expr(expr: str):
  function should_allow_eval (line 196) | def should_allow_eval(expr: str):
  function are_equal_under_sympy (line 212) | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized...
  function split_tuple (line 226) | def split_tuple(expr: str):
  function grade_answer (line 241) | def grade_answer(given_answer: str, ground_truth: str) -> bool:
  function remove_boxed (line 295) | def remove_boxed(s):
  function _last_boxed_only_string (line 305) | def _last_boxed_only_string(string):
  function match_answer (line 335) | def match_answer(response):
  function compute_score (line 380) | def compute_score(model_output: str, ground_truth: str) -> bool:

FILE: verl/utils/reward_score/prime_math/grader.py
  function is_digit (line 107) | def is_digit(s):
  function normalize (line 119) | def normalize(answer, pi) -> str:
  function handle_base (line 138) | def handle_base(x) -> str:
  function handle_pi (line 147) | def handle_pi(string, pi):
  function math_equal (line 174) | def math_equal(prediction: Union[bool, float, str],
  function symbolic_equal (line 310) | def symbolic_equal(a, b, tolerance, timeout=10.0):
  class TimeoutException (line 340) | class TimeoutException(Exception):
  function time_limit (line 345) | def time_limit(seconds: float):
  function format_intervals (line 358) | def format_intervals(prediction):

FILE: verl/utils/reward_score/prime_math/math_normalize.py
  function normalize_answer (line 43) | def normalize_answer(answer: Optional[str]) -> Optional[str]:
  function _fix_fracs (line 57) | def _fix_fracs(string):
  function _fix_a_slash_b (line 89) | def _fix_a_slash_b(string):
  function _remove_right_units (line 104) | def _remove_right_units(string):
  function _fix_sqrt (line 114) | def _fix_sqrt(string):
  function _strip_string (line 129) | def _strip_string(string):

FILE: verl/utils/seqlen_balancing.py
  function karmarkar_karp (line 25) | def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size...
  function greedy_partition (line 133) | def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_si...
  function get_seqlen_balanced_partitions (line 152) | def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions:...
  function log_seqlen_unbalance (line 186) | def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[i...
  function ceildiv (line 220) | def ceildiv(a, b):
  function rearrange_micro_batches (line 224) | def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=N...
  function get_reverse_idx (line 259) | def get_reverse_idx(idx_map):
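
karmarkar_karp and greedy_partition balance micro-batches by total sequence length. The greedy baseline is longest-processing-time scheduling: sort sequences longest first and always drop the next one into the lightest partition. A sketch of that baseline (Karmarkar-Karp is the better-known differencing refinement):

import heapq
from typing import List

def greedy_balance(seqlen_list: List[int], k_partitions: int) -> List[List[int]]:
    # heap of (current_total_length, partition_index)
    heap = [(0, p) for p in range(k_partitions)]
    heapq.heapify(heap)
    partitions = [[] for _ in range(k_partitions)]
    order = sorted(range(len(seqlen_list)), key=lambda i: -seqlen_list[i])
    for idx in order:
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)  # store original indices, not lengths
        heapq.heappush(heap, (total + seqlen_list[idx], p))
    return partitions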

FILE: verl/utils/tokenizer.py
  function set_pad_token_id (line 20) | def set_pad_token_id(tokenizer):
  function hf_tokenizer (line 35) | def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=Tr...
  function hf_processor (line 62) | def hf_processor(name_or_path, **kwargs):
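
set_pad_token_id's name implies the near-universal pad-token repair for decoder-only models (an assumption about this repo, but a standard HF pattern):

def ensure_pad_token(tokenizer):
    if tokenizer.pad_token_id is None:
        # fall back to EOS so batched generation can pad
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer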

FILE: verl/utils/torch_dtypes.py
  class PrecisionType (line 27) | class PrecisionType(object):
    method supported_type (line 43) | def supported_type(precision: Union[str, int]) -> bool:
    method supported_types (line 47) | def supported_types() -> list[str]:
    method is_fp16 (line 51) | def is_fp16(precision):
    method is_fp32 (line 55) | def is_fp32(precision):
    method is_bf16 (line 59) | def is_bf16(precision):
    method to_dtype (line 63) | def to_dtype(precision):
    method to_str (line 74) | def to_str(precision):

FILE: verl/utils/torch_functional.py
  function gather_from_labels (line 33) | def gather_from_labels(data, label):
  function logprobs_from_logits (line 48) | def logprobs_from_logits(logits, labels):
  function logprobs_from_logits_flash_attn (line 64) | def logprobs_from_logits_flash_attn(logits, labels):
  function logprobs_from_logits_naive (line 71) | def logprobs_from_logits_naive(logits, labels):
  function logprobs_from_logits_v2 (line 77) | def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
  function clip_by_value (line 97) | def clip_by_value(x, tensor_min, tensor_max):
  function entropy_from_logits (line 106) | def entropy_from_logits(logits: torch.Tensor):
  function masked_sum (line 113) | def masked_sum(values, mask, axis=None):
  function masked_mean (line 118) | def masked_mean(values, mask, axis=None):
  function masked_var (line 123) | def masked_var(values, mask, unbiased=True):
  function masked_whiten (line 141) | def masked_whiten(values, mask, shift_mean=True):
  function get_eos_mask (line 150) | def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[i...
  function compute_grad_norm (line 170) | def compute_grad_norm(model: nn.Module):
  function broadcast_dict_tensor (line 179) | def broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], Tensor...
  function allgather_dict_tensors (line 188) | def allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], Tenso...
  function split_dict_tensor_into_batches (line 222) | def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> L...
  function pad_2d_list_to_length (line 228) | def pad_2d_list_to_length(response, pad_token_id, max_length=None):
  function pad_sequence_to_length (line 242) | def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=...
  function tokenize_and_postprocess_data (line 258) | def tokenize_and_postprocess_data(prompt: str,
  function remove_pad_token (line 302) | def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tens...
  function log_probs_from_logits_response (line 317) | def log_probs_from_logits_response(input_ids, logits, response_length):
  function log_probs_from_logits_response_rmpad (line 333) | def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logi...
  function log_probs_from_logits_all_rmpad (line 361) | def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indic...
  function post_process_logits (line 392) | def post_process_logits(input_ids, logits, temperature, top_k, top_p):
  function get_cosine_schedule_with_warmup (line 412) | def get_cosine_schedule_with_warmup(
  function get_constant_schedule_with_warmup (line 455) | def get_constant_schedule_with_warmup(
  function prepare_decoder_attention_mask (line 467) | def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_e...
  function _make_causal_mask (line 489) | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, d...
  function _expand_mask (line 502) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  function get_unpad_data (line 516) | def get_unpad_data(attention_mask):
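
The masked statistics in torch_functional.py appear throughout PPO (loss averaging, advantage whitening). Sketches with the standard semantics these names carry; minor details like the epsilon are assumptions:

import torch

def masked_mean(values, mask, axis=None):
    if axis is None:
        return (values * mask).sum() / mask.sum()
    return (values * mask).sum(axis) / mask.sum(axis)

def masked_var(values, mask):
    mean = masked_mean(values, mask)
    return masked_mean((values - mean) ** 2, mask)

def masked_whiten(values, mask, shift_mean=True):
    # normalize to zero mean / unit variance over valid tokens only
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    return whitened if shift_mean else whitened + mean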

FILE: verl/utils/tracking.py
  class Tracking (line 24) | class Tracking(object):
    method __init__ (line 27) | def __init__(self, project_name, experiment_name, default_backend: Uni...
    method log (line 92) | def log(self, data, step, backend=None):
    method __del__ (line 97) | def __del__(self):
  class _TensorboardAdapter (line 108) | class _TensorboardAdapter:
    method __init__ (line 110) | def __init__(self):
    method log (line 118) | def log(self, data, step):
    method finish (line 122) | def finish(self):
  class _MlflowLoggingAdapter (line 126) | class _MlflowLoggingAdapter:
    method log (line 128) | def log(self, data, step):
  function _compute_mlflow_params_from_objects (line 133) | def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:
  function _transform_params_to_json_serializable (line 140) | def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
  function _flatten_dict (line 160) | def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:

FILE: verl/utils/ulysses.py
  function set_ulysses_sequence_parallel_group (line 29) | def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
  function get_ulysses_sequence_parallel_group (line 37) | def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
  function get_ulysses_sequence_parallel_world_size (line 45) | def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None)...
  function get_ulysses_sequence_parallel_rank (line 53) | def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
  function gather_seq_scatter_heads (line 61) | def gather_seq_scatter_heads(
  function gather_heads_scatter_seq (line 85) | def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, gro...
  function _pad_tensor (line 103) | def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
  function _unpad_tensor (line 110) | def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
  function slice_input_tensor (line 116) | def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group:...
  function all_to_all_tensor (line 132) | def all_to_all_tensor(
  function all_gather_tensor (line 154) | def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.Process...
  class SeqAllToAll (line 164) | class SeqAllToAll(torch.autograd.Function):
    method forward (line 167) | def forward(
    method backward (line 182) | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, No...
  class Gather (line 197) | class Gather(torch.autograd.Function):
    method forward (line 200) | def forward(ctx: Any,
    method backward (line 226) | def backward(ctx: Any, grad_output: Tensor) -> Any:
  function gather_outpus_and_unpad (line 233) | def gather_outpus_and_unpad(x: Tensor,
  function ulysses_pad_and_slice_inputs (line 252) | def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor,

FILE: verl/workers/actor/base.py
  class BasePPOActor (line 26) | class BasePPOActor(ABC):
    method __init__ (line 28) | def __init__(self, config):
    method compute_log_prob (line 39) | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
    method update_policy (line 54) | def update_policy(self, data: DataProto) -> Dict:

FILE: verl/workers/actor/dp_actor.py
  class DataParallelPPOActor (line 39) | class DataParallelPPOActor(BasePPOActor):
    method __init__ (line 41) | def __init__(
    method _forward_micro_batch (line 58) | def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torc...
    method _optimizer_step (line 158) | def _optimizer_step(self):
    method compute_log_prob (line 168) | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
    method update_policy (line 226) | def update_policy(self, data: DataProto):

FILE: verl/workers/actor/megatron_actor.py
  class MegatronPPOActor (line 54) | class MegatronPPOActor(BasePPOActor):
    method __init__ (line 56) | def __init__(self, config, model_config, megatron_config: ModelParalle...
    method _validate_config (line 139) | def _validate_config(self, config) -> None:
    method compute_log_prob (line 143) | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
    method make_minibatch_iterator (line 204) | def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataPro...
    method forward_backward_batch (line 232) | def forward_backward_batch(self, data: DataProto, forward_only=False, ...
    method update_policy (line 342) | def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:

FILE: verl/workers/critic/base.py
  class BasePPOCritic (line 26) | class BasePPOCritic(ABC):
    method __init__ (line 28) | def __init__(self, config):
    method compute_values (line 33) | def compute_values(self, data: DataProto) -> torch.Tensor:
    method update_critic (line 38) | def update_critic(self, data: DataProto):

FILE: verl/workers/critic/dp_critic.py
  class DataParallelPPOCritic (line 39) | class DataParallelPPOCritic(BasePPOCritic):
    method __init__ (line 41) | def __init__(self, config, critic_module: nn.Module, critic_optimizer:...
    method _forward_micro_batch (line 50) | def _forward_micro_batch(self, micro_batch):
    method _optimizer_step (line 115) | def _optimizer_step(self):
    method compute_values (line 125) | def compute_values(self, data: DataProto) -> torch.Tensor:
    method update_critic (line 166) | def update_critic(self, data: DataProto):

FILE: verl/workers/critic/megatron_critic.py
  class MegatronPPOCritic (line 43) | class MegatronPPOCritic(BasePPOCritic):
    method __init__ (line 45) | def __init__(self, config, model_config, megatron_config, critic_modul...
    method _validate_config (line 79) | def _validate_config(self, config) -> None:
    method compute_values (line 83) | def compute_values(self, data: DataProto) -> DataProto:
    method make_minibatch_iterator (line 112) | def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataPro...
    method forward_backward_batch (line 119) | def forward_backward_batch(self, data: DataProto, forward_only=False):
    method update_critic (line 207) | def update_critic(self, dataloader: Iterable[DataProto]):

FILE: verl/workers/fsdp_workers.py
  function create_device_mesh (line 48) | def create_device_mesh(world_size, fsdp_size):
  function get_sharding_strategy (line 61) | def get_sharding_strategy(device_mesh):
  class ActorRolloutRefWorker (line 72) | class ActorRolloutRefWorker(Worker):
    method __init__ (line 78) | def __init__(self, config: DictConfig, role: str):
    method _build_model_optimizer (line 142) | def _build_model_optimizer(self,
    method _build_rollout (line 296) | def _build_rollout(self):
    method init_model (line 345) | def init_model(self):
    method update_actor (line 418) | def update_actor(self, data: DataProto):
    method generate_sequences (line 461) | def generate_sequences(self, prompts: DataProto):
    method compute_log_prob (line 503) | def compute_log_prob(self, data: DataProto):
    method compute_ref_log_prob (line 537) | def compute_ref_log_prob(self, data: DataProto):
    method save_checkpoint (line 564) | def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, r...
    method load_checkpoint (line 581) | def load_checkpoint(self, path, del_local_after_load=False):
  class CriticWorker (line 594) | class CriticWorker(Worker):
    method __init__ (line 596) | def __init__(self, config):
    method _build_critic_model_optimizer (line 638) | def _build_critic_model_optimizer(self, config):
    method init_model (line 757) | def init_model(self):
    method compute_values (line 784) | def compute_values(self, data: DataProto):
    method update_critic (line 806) | def update_critic(self, data: DataProto):
    method save_checkpoint (line 841) | def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, r...
    method load_checkpoint (line 856) | def load_checkpoint(self, path, del_local_after_load=True):
  class RewardModelWorker (line 872) | class RewardModelWorker(Worker):
    method __init__ (line 877) | def __init__(self, config):
    method _build_model (line 908) | def _build_model(self, config):
    method init_model (line 970) | def init_model(self):
    method _forward_micro_batch (line 976) | def _forward_micro_batch(self, micro_batch):
    method _expand_to_token_level (line 1030) | def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
    method _switch_chat_template (line 1045) | def _switch_chat_template(self, data: DataProto):
    method compute_rm_score (line 1103) | def compute_rm_score(self, data: DataProto):
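
ActorRolloutRefWorker takes a role string, so one class can serve as actor, rollout engine, reference policy, or a hybrid of the three. A hedged sketch of that role-dispatch idea; the role names and ToyHybridWorker are illustrative guesses, not taken from the repository:

class ToyHybridWorker:
    """One worker class whose capabilities are gated by the requested role."""

    def __init__(self, config: dict, role: str):
        self.config = config
        self._is_actor = "actor" in role
        self._is_rollout = "rollout" in role
        self._is_ref = "ref" in role

    def init_model(self):
        built = []
        if self._is_actor:
            built.append("trainable policy + optimizer")
        if self._is_rollout:
            built.append("inference engine for generation")
        if self._is_ref:
            built.append("frozen reference policy")
        return built

print(ToyHybridWorker({}, role="actor_rollout_ref").init_model())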

FILE: verl/workers/megatron_workers.py
  function set_random_seed (line 50) | def set_random_seed(seed):
  class ActorRolloutRefWorker (line 66) | class ActorRolloutRefWorker(MegatronWorker):
    method __init__ (line 72) | def __init__(self, config: DictConfig, role: str):
    method _build_model_optimizer (line 134) | def _build_model_optimizer(self,
    method _build_rollout (line 231) | def _build_rollout(self):
    method init_model (line 277) | def init_model(self):
    method update_actor (line 345) | def update_actor(self, data: DataProto):
    method generate_sequences (line 369) | def generate_sequences(self, prompts: DataProto):
    method compute_ref_log_prob (line 399) | def compute_ref_log_prob(self, data: DataProto):
    method compute_log_prob (line 418) | def compute_log_prob(self, data: DataProto):
    method load_checkpoint (line 434) | def load_checkpoint(self, checkpoint_path, **kwargs):
    method load_pretrained_model (line 438) | def load_pretrained_model(self, checkpoint_path, **kwargs):
    method save_checkpoint (line 442) | def save_checkpoint(self, checkpoint_path, **kwargs):
  class CriticWorker (line 447) | class CriticWorker(MegatronWorker):
    method __init__ (line 449) | def __init__(self, config):
    method _build_critic_model_optimizer (line 487) | def _build_critic_model_optimizer(self,
    method init_model (line 556) | def init_model(self):
    method compute_values (line 594) | def compute_values(self, data: DataProto):
    method update_critic (line 602) | def update_critic(self, data: DataProto):
    method load_checkpoint (line 616) | def load_checkpoint(self, checkpoint_path, **kwargs):
    method save_checkpoint (line 620) | def save_checkpoint(self, checkpoint_path, **kwargs):
  class RewardModelWorker (line 624) | class RewardModelWorker(MegatronWorker):
    method __init__ (line 629) | def __init__(self, config):
    method _build_rm_model (line 664) | def _build_rm_model(self, model_path, megatron_config: ModelParallelCo...
    method init_model (line 723) | def init_model(self):
    method compute_rm_score (line 774) | def compute_rm_score(self, data: DataProto):

FILE: verl/workers/reward_manager/naive.py
  class NaiveRewardManager (line 20) | class NaiveRewardManager:
    method __init__ (line 24) | def __init__(self, config, tokenizer, num_examine, compute_score=None)...
    method __call__ (line 30) | def __call__(self, data: DataProto):
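
NaiveRewardManager is constructed with an optional compute_score callable and is itself called on a batch, which makes the scoring rule a plug-in. A minimal sketch of that shape; TinyRewardManager and exact_match_score are hypothetical stand-ins:

def exact_match_score(solution: str, ground_truth: str) -> float:
    # hypothetical rule: full reward iff the ground-truth answer appears verbatim
    return 1.0 if ground_truth in solution else 0.0

class TinyRewardManager:
    def __init__(self, compute_score=None):
        self.compute_score = compute_score or exact_match_score

    def __call__(self, batch):
        return [self.compute_score(sol, gt) for sol, gt in batch]

rm = TinyRewardManager()
print(rm([("the answer is \\boxed{42}", "\\boxed{42}")]))  # [1.0]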

FILE: verl/workers/reward_manager/prime.py
  function single_compute_score (line 25) | async def single_compute_score(evaluation_func, completion, reference, t...
  function parallel_compute_score_async (line 46) | async def parallel_compute_score_async(evaluation_func,
  class PrimeRewardManager (line 84) | class PrimeRewardManager:
    method __init__ (line 89) | def __init__(self, tokenizer, num_examine, compute_score=None) -> None:
    method __call__ (line 94) | def __call__(self, data: DataProto):
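
prime.py awaits single_compute_score over many completions at once. A sketch of that asyncio fan-out with a per-item timeout; score_one/score_all and the sleep placeholder are illustrative, not the repository's functions:

import asyncio

async def score_one(completion: str, reference: str, timeout: float = 5.0) -> float:
    async def _check():
        await asyncio.sleep(0)  # placeholder for real (possibly slow) evaluation
        return 1.0 if reference in completion else 0.0
    try:
        return await asyncio.wait_for(_check(), timeout=timeout)
    except asyncio.TimeoutError:
        return 0.0  # treat timeouts as incorrect

async def score_all(pairs):
    return await asyncio.gather(*(score_one(c, r) for c, r in pairs))

print(asyncio.run(score_all([("x = 2", "2"), ("x = 3", "2")])))  # [1.0, 0.0]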

FILE: verl/workers/reward_model/base.py
  class BasePPORewardModel (line 23) | class BasePPORewardModel(ABC):
    method __init__ (line 25) | def __init__(self, config):
    method compute_reward (line 29) | def compute_reward(self, data: DataProto) -> DataProto:

FILE: verl/workers/reward_model/megatron/reward_model.py
  class MegatronRewardModel (line 32) | class MegatronRewardModel(BasePPORewardModel):
    method __init__ (line 34) | def __init__(self,
    method re_encode_by_rm_tokenizer (line 53) | def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:
    method compute_reward (line 118) | def compute_reward(self, data: DataProto) -> DataProto:
    method forward_batch (line 180) | def forward_batch(self, data: DataProto):
    method offload_params_to_cpu (line 254) | def offload_params_to_cpu(self):
    method load_params_to_cuda (line 262) | def load_params_to_cuda(self):

FILE: verl/workers/rollout/base.py
  class BaseRollout (line 23) | class BaseRollout(ABC):
    method __init__ (line 25) | def __init__(self):
    method generate_sequences (line 35) | def generate_sequences(self, prompts: DataProto) -> DataProto:
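
BaseRollout pins the whole rollout surface down to one method: generate_sequences maps a batch of prompts to completed sequences, and the HF, naive, and vLLM rollouts below differ only in the engine behind it. A toy subclass, assuming plain strings in place of DataProto, to show the contract:

from abc import ABC, abstractmethod

class BaseRollout(ABC):  # local stand-in for the ABC listed above
    @abstractmethod
    def generate_sequences(self, prompts):
        """Map a batch of prompts to a batch of completed sequences."""

class EchoRollout(BaseRollout):
    # degenerate engine, useful for exercising the surrounding pipeline
    def generate_sequences(self, prompts):
        return [p + " <echo>" for p in prompts]

print(EchoRollout().generate_sequences(["1+1="]))  # ['1+1= <echo>']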

FILE: verl/workers/rollout/hf_rollout.py
  class HFRollout (line 35) | class HFRollout(BaseRollout):
    method __init__ (line 37) | def __init__(self, module: nn.Module, config):
    method generate_sequences (line 42) | def generate_sequences(self, prompts: DataProto) -> DataProto:
    method _generate_minibatch (line 51) | def _generate_minibatch(self, prompts: DataProto) -> DataProto:

FILE: verl/workers/rollout/naive/naive_rollout.py
  class NaiveRollout (line 36) | class NaiveRollout(BaseRollout):
    method __init__ (line 38) | def __init__(self, module: nn.Module, config):
    method generate_sequences (line 52) | def generate_sequences(self, prompts: DataProto) -> DataProto:

FILE: verl/workers/rollout/tokenizer.py
  class HybridEngineBaseTokenizer (line 23) | class HybridEngineBaseTokenizer(ABC):
    method vocab_size (line 28) | def vocab_size(self):
    method pad_token_id (line 36) | def pad_token_id(self):
    method eos_token_id (line 44) | def eos_token_id(self):
    method all_special_ids (line 53) | def all_special_ids(self) -> List[int]:
    method all_special_tokens (line 61) | def all_special_tokens(self) -> List[str]:
    method encode (line 70) | def encode(self, text):
    method decode (line 86) | def decode(
    method convert_ids_to_tokens (line 116) | def convert_ids_to_tokens(self,
    method get_added_vocab (line 135) | def get_added_vocab(self) -> Dict[str, int]:
    method convert_tokens_to_string (line 147) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method is_fast (line 161) | def is_fast(self):
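
HybridEngineBaseTokenizer fixes the minimal tokenizer surface the hybrid engine relies on: vocabulary size, special-token ids, encode/decode, and a few conversions. A sketch of adapting a Hugging Face tokenizer to that surface; HFTokenizerAdapter is hypothetical, while the attributes it delegates to are standard on transformers tokenizers:

class HFTokenizerAdapter:
    """Illustrative adapter exposing part of the interface listed above."""

    def __init__(self, hf_tokenizer):
        self._tok = hf_tokenizer

    @property
    def vocab_size(self):
        return self._tok.vocab_size

    @property
    def pad_token_id(self):
        return self._tok.pad_token_id

    @property
    def eos_token_id(self):
        return self._tok.eos_token_id

    def encode(self, text):
        return self._tok.encode(text)

    def decode(self, token_ids, **kwargs):
        return self._tok.decode(token_ids, **kwargs)

# usage (requires transformers):
#   from transformers import AutoTokenizer
#   tok = HFTokenizerAdapter(AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B"))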

FILE: verl/workers/rollout/vllm_rollout/__init__.py
  function get_version (line 18) | def get_version(pkg):

FILE: verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py
  function _pre_process_inputs (line 50) | def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) ->...
  class FIREvLLMRollout (line 58) | class FIREvLLMRollout(vLLMRollout):
    method __init__ (line 60) | def __init__(self, actor_module: nn.Module, config: DictConfig, tokeni...
    method update_sampling_params (line 83) | def update_sampling_params(self, **kwargs):
    method generate_sequences (line 110) | def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
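
update_sampling_params(**kwargs) appears in both vLLM rollout classes; a natural reading is a scoped override that swaps sampling parameters in for one generation call and restores them afterwards. A hedged sketch of that pattern as a context manager; ToyRollout and its dict-based params are illustrative:

from contextlib import contextmanager

class ToyRollout:
    def __init__(self):
        self.sampling_params = {"temperature": 1.0, "n": 1}

    @contextmanager
    def update_sampling_params(self, **kwargs):
        old = dict(self.sampling_params)     # snapshot current values
        self.sampling_params.update(kwargs)  # apply the overrides
        try:
            yield
        finally:
            self.sampling_params = old       # restore even if generation raises

r = ToyRollout()
with r.update_sampling_params(temperature=0.0, n=4):
    print(r.sampling_params)  # overridden inside the block
print(r.sampling_params)      # original values restored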

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/code/code_interpreter.py
  function fix_matplotlib_cjk_font_issue (line 36) | def fix_matplotlib_cjk_font_issue():
  function start_kernel (line 44) | def start_kernel(pid):
  function escape_ansi (line 91) | def escape_ansi(line):
  function publish_image_to_local (line 96) | def publish_image_to_local(image_base64: str):
  function code_interpreter (line 140) | def code_interpreter(action_input_list: list, timeout=30, clear=False):
  function _code_interpreter (line 157) | def _code_interpreter(code: str, timeout, clear=False):
  function get_multiline_input (line 226) | def get_multiline_input(hint):
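
code_interpreter runs model-written code with a timeout and returns the captured output. The repository's version keeps a persistent Jupyter kernel; the stand-in below is a deliberately simplified subprocess variant (fresh interpreter per call, no persistent state) just to show the timeout-and-capture shape:

import subprocess
import sys

def run_python(code: str, timeout: int = 30) -> str:
    # simplified stand-in: execute in a fresh interpreter, capture stdout+stderr
    try:
        proc = subprocess.run([sys.executable, "-c", code],
                              capture_output=True, text=True, timeout=timeout)
        return proc.stdout + proc.stderr
    except subprocess.TimeoutExpired:
        return "error: execution timed out"

print(run_python("print(sum(range(10)))"))  # 45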

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/code/utils/code_utils.py
  function replace_upload_fname (line 7) | def replace_upload_fname(text, upload_fname_list):
  function extract_code (line 14) | def extract_code(text):
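
extract_code pulls fenced code out of model output before it is sent to the interpreter. A minimal sketch assuming the common ```python fence convention (the exact pattern used by the repository may differ):

import re

def extract_code(text: str) -> str:
    # illustrative: return the last fenced python block, or "" if none
    blocks = re.findall(r"```python\n(.*?)```", text, re.DOTALL)
    return blocks[-1].strip() if blocks else ""

print(extract_code("reasoning...\n```python\nprint(1 + 1)\n```"))  # print(1 + 1)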

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/llm/schema.py
  class BaseModelCompatibleDict (line 22) | class BaseModelCompatibleDict(BaseModel):
    method __getitem__ (line 24) | def __getitem__(self, item):
    method __setitem__ (line 27) | def __setitem__(self, key, value):
    method model_dump (line 30) | def model_dump(self, **kwargs):
    method model_dump_json (line 35) | def model_dump_json(self, **kwargs):
    method get (line 40) | def get(self, key, default=None):
    method __str__ (line 50) | def __str__(self):
  class FunctionCall (line 54) | class FunctionCall(BaseModelCompatibleDict):
    method __init__ (line 58) | def __init__(self, name: str, arguments: str):
    method __repr__ (line 61) | def __repr__(self):
  class ContentItem (line 65) | class ContentItem(BaseModelCompatibleDict):
    method __init__ (line 72) | def __init__(self,
    method check_exclusivity (line 81) | def check_exclusivity(self):
    method __repr__ (line 98) | def __repr__(self):
    method get_type_and_value (line 101) | def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file',...
    method type (line 107) | def type(self) -> Literal['text', 'image', 'file', 'audio', 'video']:
    method value (line 112) | def value(self) -> str:
  class Message (line 117) | class Message(BaseModelCompatibleDict):
    method __init__ (line 124) | def __init__(self,
    method __repr__ (line 135) | def __repr__(self):
    method role_checker (line 139) | def role_checker(cls, value: str) -> str:
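
BaseModelCompatibleDict gives pydantic models dict-style access so downstream code can treat Message, ContentItem, and FunctionCall like plain dicts. A sketch of that pattern; ModelCompatibleDict here is a hypothetical re-creation, not the repository class:

from pydantic import BaseModel

class ModelCompatibleDict(BaseModel):
    # dict-style access delegating to the model's fields
    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def get(self, key, default=None):
        return getattr(self, key, default)

class FunctionCall(ModelCompatibleDict):
    name: str
    arguments: str

fc = FunctionCall(name="code_interpreter", arguments="{}")
print(fc["name"], fc.get("missing", "n/a"))  # code_interpreter n/a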

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/log.py
  function setup_logger (line 5) | def setup_logger(level=None):

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/tools/base.py
  class ToolServiceError (line 13) | class ToolServiceError(Exception):
    method __init__ (line 15) | def __init__(self,
  function register_tool (line 30) | def register_tool(name, allow_overwrite=False):
  function is_tool_schema (line 48) | def is_tool_schema(obj: dict) -> bool:
  class BaseTool (line 95) | class BaseTool(ABC):
    method __init__ (line 100) | def __init__(self, cfg: Optional[dict] = None):
    method call (line 112) | def call(self, params: Union[str, dict], **kwargs) -> Union[str, list,...
    method _verify_json_format_args (line 126) | def _verify_json_format_args(self, params: Union[str, dict], strict_js...
    method function (line 151) | def function(self) -> dict:  # Bad naming. It should be `function_info`.
    method name_for_human (line 161) | def name_for_human(self) -> str:
    method args_format (line 165) | def args_format(self) -> str:
    method file_access (line 175) | def file_access(self) -> bool:
  class BaseToolWithFileAccess (line 179) | class BaseToolWithFileAccess(BaseTool, ABC):
    method __init__ (line 181) | def __init__(self, cfg: Optional[Dict] = None):
    method file_access (line 188) | def file_access(self) -> bool:
    method call (line 191) | def call(self, params: Union[str, dict], files: List[str] = None, **kw...
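
register_tool(name) is a registry decorator: it maps a string name to a tool class so the tool the model names in text can be looked up and instantiated. A self-contained sketch of the technique; TOOL_REGISTRY and EchoTool are illustrative:

TOOL_REGISTRY = {}

def register_tool(name, allow_overwrite=False):
    def decorator(cls):
        if name in TOOL_REGISTRY and not allow_overwrite:
            raise ValueError(f"tool {name!r} is already registered")
        cls.name = name
        TOOL_REGISTRY[name] = cls
        return cls
    return decorator

@register_tool("echo")
class EchoTool:
    def call(self, params):
        return params

print(TOOL_REGISTRY["echo"]().call("hi"))  # hi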

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/tools/code_interpreter.py
  function _kill_kernels_and_subprocesses (line 39) | def _kill_kernels_and_subprocesses(_sig_num=None, _frame=None):
  class CodeInterpreter (line 61) | class CodeInterpreter(BaseToolWithFileAccess):
    method __init__ (line 65) | def __init__(self, cfg: Optional[Dict] = None):
    method args_format (line 73) | def args_format(self) -> str:
    method call (line 82) | def call(self, params: Union[str, dict], files: List[str] = None, time...
    method __del__ (line 127) | def __del__(self):
    method _fix_secure_write_for_code_interpreter (line 137) | def _fix_secure_write_for_code_interpreter(self):
    method _start_kernel (line 151) | def _start_kernel(self, kernel_id: str):
    method _execute_code (line 199) | def _execute_code(self, kc, code: str) -> str:
    method _serve_image (line 253) | def _serve_image(self, image_base64: str) -> str:
  function _check_deps_for_code_interpreter (line 270) | def _check_deps_for_code_interpreter():
  function _fix_matplotlib_cjk_font_issue (line 286) | def _fix_matplotlib_cjk_font_issue():
  function _escape_ansi (line 305) | def _escape_ansi(line: str) -> str:
  class AnyThreadEventLoopPolicy (line 321) | class AnyThreadEventLoopPolicy(_BasePolicy):  # type: ignore
    method get_event_loop (line 334) | def get_event_loop(self) -> asyncio.AbstractEventLoop:
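
_escape_ansi strips terminal escape sequences from kernel output before it is returned to the model. A minimal sketch covering the common CSI color codes (the repository's regex may be broader):

import re

ANSI_CSI_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")  # e.g. \x1b[31m ... \x1b[0m

def escape_ansi(line: str) -> str:
    return ANSI_CSI_RE.sub("", line)

print(escape_ansi("\x1b[31merror\x1b[0m"))  # error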

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/tools/python_executor.py
  class GenericRuntime (line 21) | class GenericRuntime:
    method __init__ (line 26) | def __init__(self):
    method exec_code (line 33) | def exec_code(self, code_piece: str) -> None:
    method eval_code (line 38) | def eval_code(self, expr: str) -> Any:
    method inject (line 41) | def inject(self, var_dict: Dict[str, Any]) -> None:
    method answer (line 46) | def answer(self):
  class DateRuntime (line 50) | class DateRuntime(GenericRuntime):
  class CustomDict (line 59) | class CustomDict(dict):
    method __iter__ (line 61) | def __iter__(self):
  class ColorObjectRuntime (line 65) | class ColorObjectRuntime(GenericRuntime):
  function _check_deps_for_python_executor (line 69) | def _check_deps_for_python_executor():
  class PythonExecutor (line 83) | class PythonExecutor(BaseTool):
    method __init__ (line 88) | def __init__(self, cfg: Optional[Dict] = None):
    method call (line 107) | def call(self, params: Union[str, dict], **kwargs) -> list:
    method apply (line 120) | def apply(self, code: str) -> list:
    method process_generation_to_code (line 123) | def process_generation_to_code(self, gens: str):
    method execute (line 127) | def execute(
    method truncate (line 161) | def truncate(s, max_length=256):
    method batch_apply (line 167) | def batch_apply(self, batch_code: List[str]) -> list:
  function _test (line 219) | def _test():
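
GenericRuntime's exec_code/eval_code/inject trio amounts to one persistent namespace: statements are exec'd into it, expressions are eval'd against it, and variables can be injected directly. A sketch of that persistent-namespace technique:

class GenericRuntime:
    # a single global namespace shared across all exec/eval calls
    def __init__(self):
        self._global_vars = {}

    def exec_code(self, code_piece: str) -> None:
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str):
        return eval(expr, self._global_vars)

    def inject(self, var_dict: dict) -> None:
        self._global_vars.update(var_dict)

rt = GenericRuntime()
rt.exec_code("x = 2\ny = x ** 10")
print(rt.eval_code("y"))  # 1024 -- state persists between calls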

FILE: verl/workers/rollout/vllm_rollout/qwen_agent/utils/utils.py
  function append_signal_handler (line 25) | def append_signal_handler(sig, handler):
  function get_local_ip (line 51) | def get_local_ip() -> str:
  function hash_sha256 (line 64) | def hash_sha256(text: str) -> str:
  function print_traceback (line 70) | def print_traceback(is_error: bool = True):
  function has_chinese_chars (line 81) | def has_chinese_chars(data: Any) -> bool:
  function has_chinese_messages (line 86) | def has_chinese_messages(messages: List[Union[Message, dict]], check_rol...
  function get_basename_from_url (line 94) | def get_basename_from_url(path_or_url: str) -> str:
  function is_http_url (line 114) | def is_http_url(path_or_url: str) -> bool:
  function is_image (line 120) | def is_image(path_or_url: str) -> bool:
  function sanitize_chrome_file_path (line 128) | def sanitize_chrome_file_path(file_path: str) -> str:
  function sanitize_windows_file_path (line 142) | def sanitize_windows_file_path(file_path: str) -> str:
  function save_url_to_local_work_dir (line 168) | def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: s...
  function save_text_to_file (line 195) | def save_text_to_file(path: str, text: str) -> None:
  function read_text_from_file (line 200) | def read_text_from_file(path: str) -> str:
  function contains_html_tags (line 212) | def contains_html_tags(text: str) -> bool:
  function get_content_type_by_head_request (line 217) | def get_content_type_by_head_request(path: str) -> str:
  function get_file_type (line 226) | def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'txt', 'h...
  function extract_urls (line 258) | def extract_urls(text: str) -> List[str]:
  function extract_markdown_urls (line 264) | def extract_markdown_urls(md_text: str) -> List[str]:
  function extract_code (line 270) | def extract_code(text: str) -> str:
  function json_loads (line 284) | def json_loads(text: str) -> dict:
  class PydanticJSONEncoder (line 297) | class PydanticJSONEncoder(json.JSONEncoder):
    method default (line 299) | def default(self, obj):
  function json_dumps_pretty (line 305) | def json_dumps_pretty(obj: dict, ensure_ascii=False, indent=2, **kwargs)...
  function json_dumps_compact (line 309) | def json_dumps_compact(obj: dict, ensure_ascii=False, indent=None, **kwa...
  function format_as_multimodal_message (line 313) | def format_as_multimodal_message(
  function format_as_text_message (line 378) | def format_as_text_message(
  function extract_text_from_message (line 395) | def extract_text_from_message(
  function extract_files_from_messages (line 409) | def extract_files_from_messages(messages: List[Message], include_images:...
  function merge_generate_cfgs (line 421) | def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_...
  function build_text_completion_prompt (line 434) | def build_text_completion_prompt(
  function encode_image_as_base64 (line 480) | def encode_image_as_base64(path: str, max_short_side_length: int = -1) -...
  function load_image_from_base64 (line 495) | def load_image_from_base64(image_base64: Union[bytes, str]):
  function resize_image (line 502) | def resize_image(img, short_side_length: int = 1080):
  function get_last_usr_msg_idx (line 519) | def get_last_usr_msg_idx(messages: List[Union[dict, Message]]) -> int:

FILE: verl/workers/rollout/vllm_rollout/vllm_rollout.py
  function _pre_process_inputs (line 49) | def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) ->...
  class vLLMRollout (line 57) | class vLLMRollout(BaseRollout):
    method __init__ (line 59) | def __init__(self, actor_module: nn.Module, config: DictConfig, tokeni...
    method update_sampling_params (line 141) | def update_sampling_params(self, **kwargs):
    method generate_sequences (line 157) | def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

FILE: verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
  function _pre_process_inputs (line 62) | def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) ->...
  function _repeat_interleave (line 70) | def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: ...
  function extract_program (line 79) | def extract_program(result: str, last_only=True):
  function _detect_tool (line 101) | def _detect_tool(text: str) -> Tuple[bool, str, str, str]:
  function send_request (line 107) | def send_request(json_data):
  class vLLMRollout (line 117) | class vLLMRollout(BaseRollout):
    method __init__ (line 119) | def __init__(self, model_path: str, config: DictConfig, tokenizer, mod...
    method _get_prompts_and_indices (line 195) | def _get_prompts_and_indices(self, samples_info):
    method code_interpreter_batch_call (line 214) | def code_interpreter_batch_call(self, tool_inputs, timeout=20):
    method _tokenize_and_find_mask_token_indices (line 238) | def _tokenize_and_find_mask_token_indices(self, sample_info):
    method _tir_generate (line 259) | def _tir_generate(self, prompts=None, sampling_params=None, prompt_tok...
    method update_sampling_params (line 373) | def update_sampling_params(self, **kwargs):
    method generate_sequences (line 389) | def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
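
_tir_generate is where tool-integrated reasoning happens: decode until a code block appears, execute it, feed the interpreter output back into the context, and continue. A pseudocode-level sketch of that loop; all names, the round limit, and the ```output fence convention are assumptions, not the method's actual implementation:

def tir_generate(generate, run_code, detect_code, prompt, max_rounds=4):
    """Toy tool-integrated generation loop.

    generate(text) -> str          continue decoding from `text`
    run_code(code) -> str          execute code, return its output
    detect_code(text) -> str|None  extract a code block, if any
    """
    text = prompt
    for _ in range(max_rounds):
        completion = generate(text)
        text += completion
        code = detect_code(completion)
        if code is None:  # no tool call -> the model produced a final answer
            break
        observation = run_code(code)
        # feed the execution result back so the model can keep reasoning
        text += f"\n```output\n{observation}\n```\n"
    return text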

FILE: verl/workers/sharding_manager/base.py
  class BaseShardingManager (line 21) | class BaseShardingManager:
    method __enter__ (line 23) | def __enter__(self):
    method __exit__ (line 26) | def __exit__(self, exc_type, exc_value, traceback):
    method preprocess_data (line 29) | def preprocess_data(self, data: DataProto) -> DataProto:
    method postprocess_data (line 32) | def postprocess_data(self, data: DataProto) -> DataProto:
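
The sharding managers are context managers wrapped around generation: __enter__ moves the up-to-date training weights into the inference engine's layout, __exit__ releases them, and preprocess_data/postprocess_data reshape batches between the two parallel layouts. A usage-shaped sketch with a logging stand-in (the print statements describe what the real managers plausibly do, not verified behavior):

class LoggingShardingManager:
    def __enter__(self):
        print("sync training weights -> inference engine")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print("release inference weights, resume training layout")

    def preprocess_data(self, data):
        return data  # e.g. gather a full batch across parallel ranks

    def postprocess_data(self, data):
        return data  # e.g. chunk outputs back to the local rank

sm = LoggingShardingManager()
with sm:
    batch = sm.preprocess_data(["prompt"])
    outputs = batch  # generation would run here
    outputs = sm.postprocess_data(outputs)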

FILE: verl/workers/sharding_manager/fsdp_ulysses.py
  class FSDPUlyssesShardingManager (line 31) | class FSDPUlyssesShardingManager(BaseShardingManager):
    method __init__ (line 36) | def __init__(self, device_mesh: DeviceMesh):
    method __enter__ (line 41) | def __enter__(self):
    method __exit__ (line 49) | def __exit__(self, exc_type, exc_value, traceback):
    method preprocess_data (line 56) | def preprocess_data(self, data: DataProto) -> DataProto:
    method postprocess_data (line 78) | def postprocess_data(self, data: DataProto) -> DataProto:

FILE: verl/workers/sharding_manager/fsdp_vllm.py
  class FSDPVLLMShardingManager (line 36) | class FSDPVLLMShardingManager(BaseShardingManager):
    method __init__ (line 38) | def __init__(self,
    method __enter__ (line 71) | def __enter__(self):
    method __exit__ (line 105) | def __exit__(self, exc_type, exc_value, traceback):
    method preprocess_data (line 128) | def preprocess_data(self, data: DataProto) -> DataProto:
    method postprocess_data (line 146) | def postprocess_data(self, data: DataProto) -> DataProto:

FILE: verl/workers/sharding_manager/megatron_vllm.py
  class AllGatherPPModel (line 37) | class AllGatherPPModel:
    method __init__ (line 39) | def __init__(self, model_provider) -> None:
    method _build_param_buffer (line 84) | def _build_param_buffer(self, pp_rank):
    method _build_param_references (line 101) | def _build_param_references(self, pp_rank, maintain_weight=False):
    method _load_params_to_cuda (line 107) | def _load_params_to_cuda(self, pp_rank, to_empty=False):
    method _offload_params_to_cpu (line 117) | def _offload_params_to_cpu(self, pp_rank, to_empty=False):
    method load_params_to_cuda (line 127) | def load_params_to_cuda(self, to_empty=False):
    method allgather_params (line 133) | def allgather_params(self):
    method forward (line 143) | def forward(self, *inputs, **kwargs):
    method __call__ (line 162) | def __call__(self, *inputs, **kwargs):
    method eval (line 165) | def eval(self):
    method train (line 169) | def train(self):
    method offload_params_to_cpu (line 173) | def offload_params_to_cpu(self, to_empty=False):
    method get_all_params (line 179) | def get_all_params(self):
    method update_this_rank_models (line 202) | def update_this_rank_models(self, new_models):
    method this_rank_models (line 207) | def this_rank_models(self):
    method pp_size (line 211) | def pp_size(self):
    method pp_rank (line 215) | def pp_rank(self):
    method pp_group (line 219) | def pp_group(self):
    method pp_models (line 223) | def pp_models(self):
  class MegatronVLLMShardingManager (line 254) | class MegatronVLLMShardingManager(BaseShardingManager):
    method __init__ (line 256) | def __init__(self, module: AllGatherPPModel, inference_engine: LLM, mo...
    method default_tp_concat_fn (line 283) | def default_tp_concat_fn(self, name, param, infer_params, model_config):
    method _post_process_params (line 334) | def _post_process_params(self, params):
    method __enter__ (line 361) | def __enter__(self):
    method __exit__ (line 376) | def __exit__(self, exc_type, exc_value, traceback):
    method preprocess_data (line 394) | def preprocess_data(self, data: DataProto) -> DataProto:
    method postprocess_data (line 410) | def postprocess_data(self, data: DataProto) -> DataProto:
  function get_micro_data_parallel_group (line 434) | def get_micro_data_parallel_group():
  function get_micro_data_parallel_world_size (line 439) | def get_micro_data_parallel_world_size():
  function get_micro_data_parallel_rank (line 443) | def get_micro_data_parallel_rank():
Condensed preview — 221 files, each showing path, character count, and a content snippet (full structured content: 1,781K chars).
[
  {
    "path": "readme.md",
    "chars": 7102,
    "preview": "# \n<div align=\"center\">\n\n# ToRL: Scaling Tool-Integrated RL\n\n</div>\n\n<p align=\"center\">\n  📄 <a href=\"https://arxiv.org/p"
  },
  {
    "path": "requirements.txt",
    "chars": 3350,
    "preview": "accelerate==1.5.2\naiohappyeyeballs==2.6.1\naiohttp==3.11.14\naiohttp-cors==0.8.0\naiosignal==1.3.2\nairportsdata==20250224\na"
  },
  {
    "path": "scripts/torl_1.5b.sh",
    "chars": 2457,
    "preview": "policy_path=Qwen/Qwen2.5-Math-1.5B\nrollout_batch_size=128\nn_samples_per_prompts=16\nepisode=300\ntemperature=1.0\nbatch_siz"
  },
  {
    "path": "verl/__init__.py",
    "chars": 1393,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/README.md",
    "chars": 1744,
    "preview": "# Models\nCommon modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Followin"
  },
  {
    "path": "verl/models/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/__init__.py",
    "chars": 944,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/checkpoint_utils/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/checkpoint_utils/llama_loader.py",
    "chars": 19729,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/checkpoint_utils/llama_saver.py",
    "chars": 18294,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/layers/__init__.py",
    "chars": 838,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/layers/parallel_attention.py",
    "chars": 20129,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/llama/megatron/layers/parallel_decoder.py",
    "chars": 6036,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/llama/megatron/layers/parallel_linear.py",
    "chars": 2787,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/models/llama/megatron/layers/parallel_mlp.py",
    "chars": 3394,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/llama/megatron/layers/parallel_rmsnorm.py",
    "chars": 1860,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/llama/megatron/modeling_llama_megatron.py",
    "chars": 29698,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/qwen2/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/__init__.py",
    "chars": 944,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/checkpoint_utils/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py",
    "chars": 20686,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py",
    "chars": 18238,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/__init__.py",
    "chars": 838,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/parallel_attention.py",
    "chars": 18750,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/parallel_decoder.py",
    "chars": 6036,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/parallel_linear.py",
    "chars": 2787,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/parallel_mlp.py",
    "chars": 3394,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/qwen2/megatron/layers/parallel_rmsnorm.py",
    "chars": 1860,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/qwen2/megatron/modeling_qwen2_megatron.py",
    "chars": 32421,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/models/registry.py",
    "chars": 3114,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/transformers/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/transformers/llama.py",
    "chars": 10213,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/transformers/monkey_patch.py",
    "chars": 3494,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/transformers/qwen2.py",
    "chars": 10104,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/transformers/qwen2_vl.py",
    "chars": 13065,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/models/weight_loader_registry.py",
    "chars": 1322,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/protocol.py",
    "chars": 24739,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/__init__.py",
    "chars": 939,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/__init__.py",
    "chars": 772,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/decorator.py",
    "chars": 15531,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/megatron/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/megatron/worker.py",
    "chars": 1551,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/megatron/worker_group.py",
    "chars": 2080,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/register_center/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/register_center/ray.py",
    "chars": 939,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/worker.py",
    "chars": 6318,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/base/worker_group.py",
    "chars": 7559,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/ray/__init__.py",
    "chars": 701,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/ray/base.py",
    "chars": 19215,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/single_controller/ray/megatron.py",
    "chars": 3036,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/__init__.py",
    "chars": 2144,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_spmd/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py",
    "chars": 17637,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/arg_utils.py",
    "chars": 11957,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/config.py",
    "chars": 25666,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/llm.py",
    "chars": 13604,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/llm_engine_sp.py",
    "chars": 35525,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/model_loader.py",
    "chars": 12265,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/model_runner.py",
    "chars": 13381,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/parallel_state.py",
    "chars": 6075,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.co"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py",
    "chars": 3047,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py",
    "chars": 4215,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_3_1/worker.py",
    "chars": 13214,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/arg_utils.py",
    "chars": 16002,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/config.py",
    "chars": 9527,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/dtensor_weight_loaders.py",
    "chars": 12378,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py",
    "chars": 4032,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/llm.py",
    "chars": 15147,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/llm_engine_sp.py",
    "chars": 12403,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/megatron_weight_loaders.py",
    "chars": 15726,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/model_loader.py",
    "chars": 13600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/model_runner.py",
    "chars": 12581,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/parallel_state.py",
    "chars": 12798,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.co"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/spmd_gpu_executor.py",
    "chars": 8647,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py",
    "chars": 3310,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_4_2/worker.py",
    "chars": 13310,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py",
    "chars": 23746,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/config.py",
    "chars": 11948,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py",
    "chars": 15827,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py",
    "chars": 1895,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/llm.py",
    "chars": 12271,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py",
    "chars": 14689,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py",
    "chars": 13939,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/model_loader.py",
    "chars": 14974,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/model_runner.py",
    "chars": 6857,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py",
    "chars": 13091,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.co"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py",
    "chars": 10184,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py",
    "chars": 3310,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_5_4/worker.py",
    "chars": 15305,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py",
    "chars": 3142,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/config.py",
    "chars": 4167,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py",
    "chars": 17356,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py",
    "chars": 1773,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/llm.py",
    "chars": 10236,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py",
    "chars": 17823,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py",
    "chars": 14484,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/model_loader.py",
    "chars": 14439,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/model_runner.py",
    "chars": 7587,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py",
    "chars": 12915,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.co"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py",
    "chars": 10115,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py",
    "chars": 1649,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/third_party/vllm/vllm_v_0_6_3/worker.py",
    "chars": 15138,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache Licens"
  },
  {
    "path": "verl/trainer/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/config/evaluation.yaml",
    "chars": 160,
    "preview": "data:\n  path: /tmp/math_Qwen2-7B-Instruct.parquet\n  prompt_key: prompt\n  response_key: responses\n  data_source_key: data"
  },
  {
    "path": "verl/trainer/config/generation.yaml",
    "chars": 2031,
    "preview": "trainer:\n  nnodes: 1\n  n_gpus_per_node: 8\n\ndata:\n  path: ~/data/rlhf/math/test.parquet\n  prompt_key: prompt\n  n_samples:"
  },
  {
    "path": "verl/trainer/config/ppo_megatron_trainer.yaml",
    "chars": 5490,
    "preview": "data:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  pro"
  },
  {
    "path": "verl/trainer/config/ppo_trainer.yaml",
    "chars": 6717,
    "preview": "data:\n  tokenizer: null\n  train_files: ~/data/rlhf/gsm8k/train.parquet\n  val_files: ~/data/rlhf/gsm8k/test.parquet\n  pro"
  },
  {
    "path": "verl/trainer/config/sft_trainer.yaml",
    "chars": 1269,
    "preview": "data:\n  train_batch_size: 256\n  micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu\n  micro_batch_"
  },
  {
    "path": "verl/trainer/fsdp_sft_trainer.py",
    "chars": 26426,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/main_eval.py",
    "chars": 2180,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/main_generation.py",
    "chars": 5742,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/main_ppo.py",
    "chars": 5611,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/ppo/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/ppo/core_algos.py",
    "chars": 13841,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Li"
  },
  {
    "path": "verl/trainer/ppo/ray_trainer.py",
    "chars": 53419,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/trainer/runtime_env.yaml",
    "chars": 122,
    "preview": "working_dir: ./\nexcludes: [\"/.git/\"]\nenv_vars:\n  TORCH_NCCL_AVOID_RECORD_STREAMS: \"1\"\n  VLLM_ATTENTION_BACKEND: \"XFORMER"
  },
  {
    "path": "verl/utils/__init__.py",
    "chars": 703,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/checkpoint/__init__.py",
    "chars": 599,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/checkpoint/checkpoint_manager.py",
    "chars": 4678,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/checkpoint/fsdp_checkpoint_manager.py",
    "chars": 7182,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/config.py",
    "chars": 839,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/dataset/README.md",
    "chars": 797,
    "preview": "# Dataset Format\n## RLHF dataset\nWe combine all the data sources into a single parquet files. We directly organize the p"
  },
  {
    "path": "verl/utils/dataset/__init__.py",
    "chars": 707,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/dataset/rl_dataset.py",
    "chars": 10910,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/dataset/rm_dataset.py",
    "chars": 5492,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/dataset/sft_dataset.py",
    "chars": 6893,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/debug/__init__.py",
    "chars": 646,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/debug/performance.py",
    "chars": 1214,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/debug/trajectory_tracker.py",
    "chars": 3174,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/distributed.py",
    "chars": 1124,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/flops_counter.py",
    "chars": 4963,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/fs.py",
    "chars": 3140,
    "preview": "#!/usr/bin/env python\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Versi"
  },
  {
    "path": "verl/utils/fsdp_utils.py",
    "chars": 13578,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/hdfs_io.py",
    "chars": 4562,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/import_utils.py",
    "chars": 1351,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/logger/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/logger/aggregate_logger.py",
    "chars": 1400,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/logging_utils.py",
    "chars": 843,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/megatron/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/megatron/memory.py",
    "chars": 1534,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/megatron/optimizer.py",
    "chars": 1569,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n#"
  },
  {
    "path": "verl/utils/megatron/pipeline_parallel.py",
    "chars": 2119,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n#"
  },
  {
    "path": "verl/utils/megatron/sequence_parallel.py",
    "chars": 1810,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n#"
  },
  {
    "path": "verl/utils/megatron/tensor_parallel.py",
    "chars": 7257,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n#"
  },
  {
    "path": "verl/utils/megatron_utils.py",
    "chars": 11725,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n#"
  },
  {
    "path": "verl/utils/memory_buffer.py",
    "chars": 8093,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/model.py",
    "chars": 13759,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/py_functional.py",
    "chars": 1656,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/ray_utils.py",
    "chars": 1401,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/rendezvous/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/rendezvous/ray_backend.py",
    "chars": 2720,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/reward_score/__init__.py",
    "chars": 1904,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/reward_score/eval.py",
    "chars": 16613,
    "preview": "import os\nimport json\nimport jsonlines\nimport re\nimport copy\n\nPATTERNS=[\n    r\"(?i)Answer\\s*:\\s*([^\\n]+)\",\n    r\"\\\\boxed"
  },
  {
    "path": "verl/utils/reward_score/geo3k.py",
    "chars": 1224,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/reward_score/gsm8k.py",
    "chars": 2385,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/reward_score/math.py",
    "chars": 6717,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/utils/reward_score/math_verifier.py",
    "chars": 11259,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rig"
  },
  {
    "path": "verl/utils/reward_score/prime_code/__init__.py",
    "chars": 3151,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/utils/reward_score/prime_code/testing_util.py",
    "chars": 28040,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/utils/reward_score/prime_code/utils.py",
    "chars": 2252,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/utils/reward_score/prime_math/__init__.py",
    "chars": 12179,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/utils/reward_score/prime_math/grader.py",
    "chars": 14508,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "verl/utils/reward_score/prime_math/math_normalize.py",
    "chars": 6271,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/utils/seqlen_balancing.py",
    "chars": 10327,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/tokenizer.py",
    "chars": 3100,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/torch_dtypes.py",
    "chars": 2263,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/torch_functional.py",
    "chars": 19857,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/tracking.py",
    "chars": 6115,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/utils/ulysses.py",
    "chars": 10619,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/version/version",
    "chars": 10,
    "preview": "0.2.0.dev\n"
  },
  {
    "path": "verl/workers/__init__.py",
    "chars": 600,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/actor/__init__.py",
    "chars": 727,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/actor/base.py",
    "chars": 2071,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/actor/dp_actor.py",
    "chars": 17265,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/actor/megatron_actor.py",
    "chars": 18891,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/critic/__init__.py",
    "chars": 732,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/critic/base.py",
    "chars": 1094,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/critic/dp_critic.py",
    "chars": 12427,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/critic/megatron_critic.py",
    "chars": 10542,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/fsdp_workers.py",
    "chars": 58305,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/megatron_workers.py",
    "chars": 38900,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/reward_manager/__init__.py",
    "chars": 672,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/workers/reward_manager/naive.py",
    "chars": 3450,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/reward_manager/prime.py",
    "chars": 6254,
    "preview": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "verl/workers/reward_model/__init__.py",
    "chars": 638,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/reward_model/base.py",
    "chars": 1710,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/reward_model/megatron/__init__.py",
    "chars": 647,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/reward_model/megatron/reward_model.py",
    "chars": 12739,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/__init__.py",
    "chars": 753,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/base.py",
    "chars": 1141,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/hf_rollout.py",
    "chars": 5867,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/naive/__init__.py",
    "chars": 641,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/naive/naive_rollout.py",
    "chars": 4937,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  },
  {
    "path": "verl/workers/rollout/tokenizer.py",
    "chars": 5747,
    "preview": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\")"
  }
]

// ... and 21 more files (download for full content)
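The manifest above is plain JSON, so it can be processed programmatically before feeding files to a model. A minimal sketch, assuming the array has been saved locally as manifest.json (a hypothetical filename, not part of the extraction), that loads the entries and lists the largest files by character count:

import json

# Load the manifest entries; each has "path", "chars", and "preview" keys,
# matching the array printed above. "manifest.json" is a hypothetical name
# for a local copy of that array.
with open("manifest.json", "r", encoding="utf-8") as f:
    entries = json.load(f)

# Largest files first; e.g. verl/workers/fsdp_workers.py (58,305 chars)
# should surface near the top.
entries.sort(key=lambda e: e["chars"], reverse=True)

total = sum(e["chars"] for e in entries)
print(f"{len(entries)} files, {total:,} chars total")
for e in entries[:10]:
    print(f'{e["chars"]:>8,}  {e["path"]}')

The same structure also makes it easy to filter by path prefix, for example keeping only verl/workers/ entries when a smaller context is needed.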

About this extraction

This page contains the full source code of the GAIR-NLP/ToRL GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 221 files (1.6 MB), approximately 380.6k tokens, and a symbol index of 1,618 functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
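Token counts like the 380.6k figure above depend on the tokenizer, so treat them as estimates. A minimal sketch for reproducing a rough count, assuming the full extraction has been downloaded as torl_full.txt (a hypothetical filename) and that the tiktoken package is installed:

import tiktoken

# cl100k_base is the encoding used by several OpenAI chat models; other
# tokenizers will give somewhat different totals.
enc = tiktoken.get_encoding("cl100k_base")

# "torl_full.txt" is a hypothetical name for the downloaded .txt export.
with open("torl_full.txt", "r", encoding="utf-8") as f:
    text = f.read()

print(f"{len(enc.encode(text)):,} tokens for {len(text):,} characters")

A total in this range exceeds the context window of most current models, so in practice the dump is consumed per file, which is exactly what the manifest above supports.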

Extracted by GitExtract, a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.
