Full Code of InternLM/OREAL

Repository: InternLM/OREAL
Branch: main
Commit: 2859cc092ddf
Files: 18
Total size: 144.0 KB

Directory structure:
OREAL/

├── .gitignore
├── LICENSE
├── README.md
├── oreal/
│   ├── configs/
│   │   ├── oreal_w_tokenrm_DSR1-Distll-Qwen-7B_seqlen16k.py
│   │   ├── oreal_w_tokenrm_OREAL-32B-SFT_seqlen16k.py
│   │   ├── oreal_w_tokenrm_OREAL-7B-SFT_seqlen16k.py
│   │   └── oreal_wo_tokenrm_OREAL-7B-SFT_seqlen16k.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── prompt.py
│   │   └── trajectory.py
│   ├── judgers/
│   │   ├── __init__.py
│   │   ├── base_judger.py
│   │   ├── math_judger.py
│   │   ├── router.py
│   │   └── utils.py
│   └── utils.py
├── requirements.txt
└── train_oreal.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

src/

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/*/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
data/
data
.vscode
.idea
.DS_Store
*.pkl
*.pkl.json
*.log.json
work_dirs/

# Pytorch
*.pth
*.py~
*.sh~

# srun
*.out
batchscript-*


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# OREAL: Exploring the Limit of Outcome Reward for Learning Mathematical Reasoning


[![license](https://img.shields.io/github/license/InternLM/opencompass.svg)](./LICENSE)
[![arXiv](https://img.shields.io/badge/arXiv-2502.06781-b31b1b.svg)](https://arxiv.org/abs/2502.06781)
[![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OREAL-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/collections/internlm/oreal-67aaccf5a8192c1ba3cff018)


## ✨ Introduction

![main_fig](./figures/main_fig.jpg)

Reasoning abilities, especially those for solving complex math problems, are crucial components of general intelligence.
Recent advances by proprietary companies, such as OpenAI's o-series models, have made remarkable progress on reasoning tasks. However, the complete technical details remain undisclosed, and the only techniques widely believed to be in use are reinforcement learning (RL) and long chains of thought.

We propose a new RL framework, termed OREAL, to pursue the performance limit achievable through **O**utcome **RE**w**A**rd-based reinforcement **L**earning for mathematical reasoning tasks, where only binary outcome rewards are easily accessible.

+ We theoretically prove that behavior cloning on positive trajectories from best-of-N (BoN) sampling is sufficient to learn the KL-regularized optimal policy in binary feedback environments.
+ This formulation further implies that the rewards of negative samples should be reshaped to ensure gradient consistency between positive and negative samples.
+ To alleviate the long-standing difficulty of sparse rewards in RL, which is exacerbated by the partial correctness of long chains of thought in reasoning tasks, we further apply a token-level reward model to sample important tokens in reasoning trajectories for learning.

The OREAL implementation pseudocode is as follows:

![algo](./figures/algo.png)
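
For intuition, the reward-shaping step can be sketched in a few lines of PyTorch. The snippet below is an illustrative sketch of a leave-one-out (RLOO) advantage with down-weighted negatives, mirroring `reward_shaping_type = "rloo"` and `positive_loss_factor` / `negative_loss_factor` in the configs; it is not the exact code in `train_oreal.py`.

```python
import torch


def rloo_advantage(rewards: torch.Tensor) -> torch.Tensor:
    """Leave-one-out advantage for k binary rewards of one prompt."""
    k = rewards.numel()
    # baseline for each sample is the mean reward of the other k - 1 samples
    baseline = (rewards.sum() - rewards) / (k - 1)
    return rewards - baseline


# 4 rollouts of one prompt, judged 1 (correct) or 0 (incorrect)
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
adv = rloo_advantage(rewards)
# reshape negatives so their gradient contribution stays consistent with positives
adv = torch.where(rewards > 0, adv * 1.0, adv * 0.5)
```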


## 📃 Key Results

With OREAL, a 7B model can for the first time obtain 94.0 pass@1 accuracy on MATH-500 through RL, on par with 32B models. OREAL-32B also surpasses previous 32B models trained by distillation, reaching 95.0 pass@1 accuracy on MATH-500.

![main_table](./figures/main_table.png)

## 🤗 HuggingFace

### Model

Our OREAL models are available on Hugging Face 🤗:

| Model    | Huggingface Repo |
|----------|------------------|
| OREAL-DeepSeek-R1-Distill-Qwen-7B  | [Model Link](https://huggingface.co/internlm/OREAL-DeepSeek-R1-Distill-Qwen-7B) |
| OREAL-7B  | [Model Link](https://huggingface.co/internlm/OREAL-7B)  |
| OREAL-32B  | [Model Link](https://huggingface.co/internlm/OREAL-32B)  |

We also release the SFT versions of the models, so you can build your own RL pipeline on top of them :)

| Model    | Huggingface Repo |
|----------|------------------|
| OREAL-7B-SFT  | [Model Link](https://huggingface.co/internlm/OREAL-7B-SFT)  |
| OREAL-32B-SFT  | [Model Link](https://huggingface.co/internlm/OREAL-32B-SFT)  |

### Data

We release the prompts utilized in our RL training phase.

| Dataset    | Huggingface Repo |
|----------|------------------|
| RL Prompts  | [Dataset Link](https://huggingface.co/datasets/internlm/OREAL-RL-Prompts)  |

## 🚄 Training Tutorial

### 1. Install Dependencies

OREAL utilizes [XTuner](https://github.com/InternLM/xtuner/tree/main) as the training engine. 

```bash
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip install flash-attn --no-build-isolation
pip install -r requirements.txt
```

### 2. Prepare Data (Optional)

The training data can be found [here](https://huggingface.co/datasets/internlm/OREAL-RL-Prompts). The training script will automatically download it from Hugging Face.
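
To inspect the prompts locally before training, a quick look with 🤗 `datasets` (the field names follow `oreal/datasets/prompt.py`):

```python
from datasets import load_dataset

# each sample carries a question, a gold answer, and a pass_rate that
# data_difficulty_balance_cfg uses to up-sample harder prompts
ds = load_dataset("internlm/OREAL-RL-Prompts", split="train")
print(len(ds))
print(ds[0]["question"], ds[0]["gold_answer"], ds[0]["pass_rate"])
```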

### 3. Start LLM Verifier Service

OREAL requires a language model as a verifier, together with a rule-based verification function (see the [source code](oreal/judgers/math_judger.py)), to evaluate the correctness of the generated solutions. We use Qwen2.5-72B-Instruct as the verifier in our experiments. You can start the verifier service with [lmdeploy](https://github.com/InternLM/lmdeploy) by running the following command:

```bash
lmdeploy serve api_server Qwen/Qwen2.5-72B-Instruct --tp 4 --chat-template qwen --log-level INFO --server-port 10003
```

Alternatively, you can use any other inference engine, such as [sglang](https://github.com/sgl-project/sglang), [vllm](https://github.com/vllm-project/vllm), or [ollama](https://ollama.com/). Just make sure the verifier service is reachable via an OpenAI-compatible API.
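
A quick smoke test for the service (the address and served model name below are placeholders for your own deployment):

```python
from openai import OpenAI

# placeholder address; use the host/port passed to `lmdeploy serve api_server`
client = OpenAI(base_url="http://YOUR_JUDGER_HOST:10003/v1", api_key="dummy")
resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-72B-Instruct",
    messages=[{"role": "user", "content": "Reply with OK."}],
)
print(resp.choices[0].message.content)
```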

Fill in the verifier service address in the [config file](./oreal/configs) before training.

```python
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=["x.x.x.x:xxxx", "x.x.x.x:xxxx"],  # verifier service addresses
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=8,
        concurrency_per_proc=(8, 8),
    )
)
```
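
Each sample's `data_source` field (set to `"math"` in `oreal/datasets/prompt.py`) is routed to the matching judger through `data_judger_mapping = dict(math=["math_judger"])` in the same config file.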

### 4. Train OREAL

**OREAL-7B**

Training the 7B model requires 32 GPUs. You can use the following command to train it with [OREAL-7B-SFT](https://huggingface.co/internlm/OREAL-7B-SFT) as the initial policy:

```bash
torchrun --nnodes 4 --nproc_per_node 8 --master_addr $MASTER_ADDR --node_rank $RANK --master_port $MASTER_PORT train_oreal.py oreal/configs/oreal_w_tokenrm_OREAL-7B-SFT_seqlen16k.py --total_steps 90 --work_dir ./work_dir/oreal_w_tokenrm_OREAL-7B-SFT_seqlen16k
```

Training for 90 steps takes about 9 hours on 32 A100 GPUs.

**OREAL-32B**

Training the 32B model requires 128 GPUs. You can use the following command to train it with [OREAL-32B-SFT](https://huggingface.co/internlm/OREAL-32B-SFT) as the initial policy:

```bash
torchrun --nnodes 16 --nproc_per_node 8 --master_addr $MASTER_ADDR --node_rank $RANK --master_port $MASTER_PORT train_oreal.py oreal/configs/oreal_w_tokenrm_OREAL-32B-SFT_seqlen16k.py --total_steps 90 --work_dir ./work_dir/oreal_w_tokenrm_OREAL-32B-SFT_seqlen16k
```

More detailed training settings can be found in the [oreal/configs](./oreal/configs) folder.

**Note**:

+ The best checkpoint may not be the last one. Consider evaluating during training and stopping early once performance saturates.


## 🖊️ Citation

```
@article{lyu2025exploring,
  title={Exploring the Limit of Outcome Reward for Learning Mathematical Reasoning},
  author={Lyu, Chengqi and Gao, Songyang and Gu, Yuzhe and Zhang, Wenwei and Gao, Jianfei and Liu, Kuikun and Wang, Ziyi and Li, Shuaibin and Zhao, Qian and Huang, Haian and others},
  journal={arXiv preprint arXiv:2502.06781},
  year={2025}
}
```

## 💳 License

This project is released under the Apache 2.0 [license](./LICENSE).


================================================
FILE: oreal/configs/oreal_w_tokenrm_DSR1-Distll-Qwen-7B_seqlen16k.py
================================================
# Model Related Settings
actor = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
reference = actor
token_level_rm = actor

# Tokenizer related settings
# jinja2 template for hf tokenizer
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
stop_word = "<|end▁of▁sentence|>"

dtype = "auto"
selective_recompute = 1.0
cpu_offload = False
cuda_graph = True
tp_size = 4
sp_size = 1

# Dataset Related Settings
data_difficulty_balance_cfg = [
    # pass rate range, repeat times
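    # e.g. a prompt solved in fewer than 20% of rollouts is repeated 6 times;
    # prompts matching no range (pass rate >= 0.8) are dropped (see
    # balance_difficulty_with_cfg in oreal/datasets/prompt.py)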
    ((0.0, 0.2), 6),
    ((0.2, 0.4), 4),
    ((0.4, 0.6), 4),
    ((0.6, 0.8), 2),
]
datasets = "internlm/OREAL-RL-Prompts"
num_workers = 0

# Generate Related Settings
gen_global_batch = 1024
gen_max_new = 14000
gen_max_length = 16384
gen_top_k = 0  # set to 0 to disable top-k sampling
gen_top_p = 0.9
temperature = 1.0
gen_do_sample = True
max_prefill_batch = 16
prompt_repeat_k = 16  # sample k times for each prompt

# Optimizer Related Settings
rl_global_batch = gen_global_batch
rl_mirco_batch = 2
filter_trajectory = True  # keep one correct and one incorrect trajectory for each prompt
warmup_steps = 10
total_steps = 90
actor_freeze_steps = 10  # freeze actor and only update token level reward model for the first 10 steps
actor_lr = 5e-7
actor_min_lr = 1e-7
token_level_rm_lr = 2e-6
token_level_rm_lr_min = 4e-7
wd = 0.01  # weight decay
max_grad_norm = 1  # gradient clipping

# importance sampling setting with token level reward model
threshold_rescale = True
correct_threshold = 0.5
incorrect_threshold = 0.5
# topk_rescale = True
# correct_topk_ratio = 0.25
# incorrect_topk_ratio = 0.25

reward_shaping_type = "rloo"
loss_type = "per_token"
positive_loss_factor = 1.0
negative_loss_factor = 0.5
pos_mult_adv = True
kl_coef = 0.01  # KL coefficient

# General Settings
work_dir = "work_dirs"  # directory to save logs and checkpoints
checkpoint_interval = 10  # interval to save checkpoint, <1 means save by proportion, >=1 means save by steps
log_interval = 1  # interval steps for logging
seed = 0  # random seed
debug = False  # set log level to DEBUG

# judger related settings
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=[
            "YOUR_JUDGER_HOST1:PORT",
            "YOUR_JUDGER_HOST2:PORT",
        ],
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=8,
        concurrency_per_proc=(8, 8),
    )
)
data_judger_mapping = dict(math=["math_judger"])


================================================
FILE: oreal/configs/oreal_w_tokenrm_OREAL-32B-SFT_seqlen16k.py
================================================
# Model Related Settings
actor = 'internlm/OREAL-32B-SFT'
reference = actor
token_level_rm = actor

# Tokenizer related settings
# jinja2 template for hf tokenizer
chat_template = "{% set sys_prompt = \"You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:\\n\\n## Deep Understanding\\nTake time to fully comprehend the problem before attempting a solution. Consider:\\n- What is the real question being asked?\\n- What are the given conditions and what do they tell us?\\n- Are there any special restrictions or assumptions?\\n- Which information is crucial and which is supplementary?\\n\\n## Multi-angle Analysis\\nBefore solving, conduct thorough analysis:\\n- What mathematical concepts and properties are involved?\\n- Can you recall similar classic problems or solution methods?\\n- Would diagrams or tables help visualize the problem?\\n- Are there special cases that need separate consideration?\\n\\n## Systematic Thinking\\nPlan your solution path:\\n- Propose multiple possible approaches\\n- Analyze the feasibility and merits of each method\\n- Choose the most appropriate method and explain why\\n- Break complex problems into smaller, manageable steps\\n\\n## Rigorous Proof\\nDuring the solution process:\\n- Provide solid justification for each step\\n- Include detailed proofs for key conclusions\\n- Pay attention to logical connections\\n- Be vigilant about potential oversights\\n\\n## Repeated Verification\\nAfter completing your solution:\\n- Verify your results satisfy all conditions\\n- Check for overlooked special cases\\n- Consider if the solution can be optimized or simplified\\n- Review your reasoning process\\n\\nRemember:\\n1. Take time to think thoroughly rather than rushing to an answer\\n2. Rigorously prove each key conclusion\\n3. Keep an open mind and try different approaches\\n4. Summarize valuable problem-solving methods\\n5. Maintain healthy skepticism and verify multiple times\\n\\nYour response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.\\n\\nWhen you're ready, present your complete solution with:\\n- Clear problem understanding\\n- Detailed solution process\\n- Key insights\\n- Thorough verification\\n\\nFocus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. 
Provide answers in the same language as the user asking the question, repeat the final answer using a '\\\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.\" %}{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- sys_prompt }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\n' ~ sys_prompt ~ '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
stop_word = "<|im_end|>"

dtype = "auto"
selective_recompute = 1.0
cpu_offload = False
cuda_graph = True
tp_size = 8
sp_size = 1

# Dataset Related Settings
data_difficulty_balance_cfg = [
    # pass rate range, repeat times
    ((0.0, 0.2), 6),
    ((0.2, 0.4), 4),
    ((0.4, 0.6), 4),
    ((0.6, 0.8), 2),
]
datasets = "internlm/OREAL-RL-Prompts"
num_workers = 0

# Generate Related Settings
gen_global_batch = 1024
gen_max_new = 14000
gen_max_length = 16384
gen_top_k = 0  # set to 0 to disable top-k sampling
gen_top_p = 0.9
temperature = 1.0
gen_do_sample = True
max_prefill_batch = 16
prompt_repeat_k = 16  # sample k times for each prompt

# Optimizer Related Settings
rl_global_batch = gen_global_batch
rl_mirco_batch = 2
filter_trajectory = True  # keep one correct and one incorrect trajectory for each prompt
warmup_steps = 10
total_steps = 90
actor_freeze_steps = 10  # freeze actor and only update token level reward model for the first 10 steps
actor_lr = 5e-7
actor_min_lr = 1e-7
token_level_rm_lr = 2e-6
token_level_rm_lr_min = 4e-7
wd = 0.01  # weight decay
max_grad_norm = 1  # gradient clipping

# importance sampling setting with token level reward model
threshold_rescale = True
correct_threshold = 0.5
incorrect_threshold = 0.5
# topk_rescale = True
# correct_topk_ratio = 0.25
# incorrect_topk_ratio = 0.25

reward_shaping_type = "rloo"
loss_type = "per_token"
positive_loss_factor = 1.0
negative_loss_factor = 0.5
pos_mult_adv = True
kl_coef = 0.01  # KL coefficient

# General Settings
work_dir = "work_dirs"  # directory to save logs and checkpoints
checkpoint_interval = 10  # interval to save checkpoint, <1 means save by proportion, >=1 means save by steps
log_interval = 1  # interval steps for logging
seed = 0  # random seed
debug = False  # set log level to DEBUG

# judger related settings
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=[
            "YOUR_JUDGER_HOST1:PORT",
            "YOUR_JUDGER_HOST2:PORT",
        ],
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=8,
        concurrency_per_proc=(8, 8),
    )
)
data_judger_mapping = dict(math=["math_judger"])


================================================
FILE: oreal/configs/oreal_w_tokenrm_OREAL-7B-SFT_seqlen16k.py
================================================
# Model Related Settings
actor = "internlm/OREAL-7B-SFT"
reference = actor
token_level_rm = actor

# Tokenizer related settings
# jinja2 template for hf tokenizer
chat_template = "{% set sys_prompt = \"You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:\\n\\n## Deep Understanding\\nTake time to fully comprehend the problem before attempting a solution. Consider:\\n- What is the real question being asked?\\n- What are the given conditions and what do they tell us?\\n- Are there any special restrictions or assumptions?\\n- Which information is crucial and which is supplementary?\\n\\n## Multi-angle Analysis\\nBefore solving, conduct thorough analysis:\\n- What mathematical concepts and properties are involved?\\n- Can you recall similar classic problems or solution methods?\\n- Would diagrams or tables help visualize the problem?\\n- Are there special cases that need separate consideration?\\n\\n## Systematic Thinking\\nPlan your solution path:\\n- Propose multiple possible approaches\\n- Analyze the feasibility and merits of each method\\n- Choose the most appropriate method and explain why\\n- Break complex problems into smaller, manageable steps\\n\\n## Rigorous Proof\\nDuring the solution process:\\n- Provide solid justification for each step\\n- Include detailed proofs for key conclusions\\n- Pay attention to logical connections\\n- Be vigilant about potential oversights\\n\\n## Repeated Verification\\nAfter completing your solution:\\n- Verify your results satisfy all conditions\\n- Check for overlooked special cases\\n- Consider if the solution can be optimized or simplified\\n- Review your reasoning process\\n\\nRemember:\\n1. Take time to think thoroughly rather than rushing to an answer\\n2. Rigorously prove each key conclusion\\n3. Keep an open mind and try different approaches\\n4. Summarize valuable problem-solving methods\\n5. Maintain healthy skepticism and verify multiple times\\n\\nYour response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.\\n\\nWhen you're ready, present your complete solution with:\\n- Clear problem understanding\\n- Detailed solution process\\n- Key insights\\n- Thorough verification\\n\\nFocus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. 
Provide answers in the same language as the user asking the question, repeat the final answer using a '\\\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.\" %}{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- sys_prompt }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\n' ~ sys_prompt ~ '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
stop_word = "<|im_end|>"

dtype = "auto"
selective_recompute = 1.0
cpu_offload = False
cuda_graph = True
tp_size = 4
sp_size = 1

# Dataset Related Settings
data_difficulty_balance_cfg = [
    # pass rate range, repeat times
    ((0.0, 0.2), 6),
    ((0.2, 0.4), 4),
    ((0.4, 0.6), 4),
    ((0.6, 0.8), 2),
]
datasets = "internlm/OREAL-RL-Prompts"
num_workers = 0

# Generate Related Settings
gen_global_batch = 1024
gen_max_new = 14000
gen_max_length = 16384
gen_top_k = 0  # set to 0 to disable top-k sampling
gen_top_p = 0.9
temperature = 1.0
gen_do_sample = True
max_prefill_batch = 16
prompt_repeat_k = 16  # sample k times for each prompt

# Optimizer Related Settings
rl_global_batch = gen_global_batch
rl_mirco_batch = 2
filter_trajectory = True  # keep one correct and one incorrect trajectory for each prompt
warmup_steps = 10
total_steps = 90
actor_freeze_steps = 10  # freeze actor and only update token level reward model for the first 10 steps
actor_lr = 5e-7
actor_min_lr = 1e-7
token_level_rm_lr = 2e-6
token_level_rm_lr_min = 4e-7
wd = 0.01  # weight decay
max_grad_norm = 1  # gradient clipping

# importance sampling setting with token level reward model
threshold_rescale = True
correct_threshold = 0.5
incorrect_threshold = 0.5
# topk_rescale = True
# correct_topk_ratio = 0.25
# incorrect_topk_ratio = 0.25

reward_shaping_type = "rloo"
loss_type = "per_token"
positive_loss_factor = 1.0
negative_loss_factor = 0.5
pos_mult_adv = True
kl_coef = 0.01  # KL coefficient

# General Settings
work_dir = "work_dirs"  # directory to save logs and checkpoints
checkpoint_interval = 10  # interval to save checkpoint, <1 means save by proportion, >=1 means save by steps
log_interval = 1  # interval steps for logging
seed = 0  # random seed
debug = False  # set log level to DEBUG

# judger related settings
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=[
            "YOUR_JUDGER_HOST1:PORT",
            "YOUR_JUDGER_HOST2:PORT",
        ],
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=8,
        concurrency_per_proc=(8, 8),
    )
)
data_judger_mapping = dict(math=["math_judger"])


================================================
FILE: oreal/configs/oreal_wo_tokenrm_OREAL-7B-SFT_seqlen16k.py
================================================
# Model Related Settings
actor = "internlm/OREAL-7B-SFT"
reference = actor
token_level_rm = None

# Tokenizer related settings
# jinja2 template for hf tokenizer
chat_template = "{% set sys_prompt = \"You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:\\n\\n## Deep Understanding\\nTake time to fully comprehend the problem before attempting a solution. Consider:\\n- What is the real question being asked?\\n- What are the given conditions and what do they tell us?\\n- Are there any special restrictions or assumptions?\\n- Which information is crucial and which is supplementary?\\n\\n## Multi-angle Analysis\\nBefore solving, conduct thorough analysis:\\n- What mathematical concepts and properties are involved?\\n- Can you recall similar classic problems or solution methods?\\n- Would diagrams or tables help visualize the problem?\\n- Are there special cases that need separate consideration?\\n\\n## Systematic Thinking\\nPlan your solution path:\\n- Propose multiple possible approaches\\n- Analyze the feasibility and merits of each method\\n- Choose the most appropriate method and explain why\\n- Break complex problems into smaller, manageable steps\\n\\n## Rigorous Proof\\nDuring the solution process:\\n- Provide solid justification for each step\\n- Include detailed proofs for key conclusions\\n- Pay attention to logical connections\\n- Be vigilant about potential oversights\\n\\n## Repeated Verification\\nAfter completing your solution:\\n- Verify your results satisfy all conditions\\n- Check for overlooked special cases\\n- Consider if the solution can be optimized or simplified\\n- Review your reasoning process\\n\\nRemember:\\n1. Take time to think thoroughly rather than rushing to an answer\\n2. Rigorously prove each key conclusion\\n3. Keep an open mind and try different approaches\\n4. Summarize valuable problem-solving methods\\n5. Maintain healthy skepticism and verify multiple times\\n\\nYour response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.\\n\\nWhen you're ready, present your complete solution with:\\n- Clear problem understanding\\n- Detailed solution process\\n- Key insights\\n- Thorough verification\\n\\nFocus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. 
Provide answers in the same language as the user asking the question, repeat the final answer using a '\\\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.\" %}{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- sys_prompt }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\n' ~ sys_prompt ~ '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
stop_word = "<|im_end|>"

dtype = "auto"
selective_recompute = 1.0
cpu_offload = False
cuda_graph = True
tp_size = 4
sp_size = 1

# Dataset Related Settings
data_difficulty_balance_cfg = [
    # pass rate range, repeat times
    ((0.0, 0.2), 6),
    ((0.2, 0.4), 4),
    ((0.4, 0.6), 4),
    ((0.6, 0.8), 2),
]
datasets = "internlm/OREAL-RL-Prompts"
num_workers = 0

# Generate Related Settings
gen_global_batch = 1024
gen_max_new = 14000
gen_max_length = 16384
gen_top_k = 0  # set to 0 to disable top-k sampling
gen_top_p = 0.9
temperature = 1.0
gen_do_sample = True
max_prefill_batch = 16
prompt_repeat_k = 16  # sample k times for each prompt

# Optimizer Related Settings
rl_global_batch = gen_global_batch
rl_mirco_batch = 2
filter_trajectory = False
warmup_steps = 10
total_steps = 90
actor_freeze_steps = 0
actor_lr = 5e-7
actor_min_lr = 1e-7
token_level_rm_lr = 2e-6
token_level_rm_lr_min = 4e-7
wd = 0.01  # weight decay
max_grad_norm = 1  # gradient clipping

# importance sampling setting with token level reward model
threshold_rescale = True
correct_threshold = 0.5
incorrect_threshold = 0.5
# topk_rescale = True
# correct_topk_ratio = 0.25
# incorrect_topk_ratio = 0.25

reward_shaping_type = "rloo"
loss_type = "per_token"
positive_loss_factor = 1.0
negative_loss_factor = 0.5
pos_mult_adv = True
kl_coef = 0.01  # KL coefficient

# General Settings
work_dir = "work_dirs"  # directory to save logs and checkpoints
checkpoint_interval = 10  # interval to save checkpoint, <1 means save by proportion, >=1 means save by steps
log_interval = 1  # interval steps for logging
seed = 0  # random seed
debug = False  # set log level to DEBUG

# judger related settings
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=[
            "YOUR_JUDGER_HOST1:PORT",
            "YOUR_JUDGER_HOST2:PORT",
        ],
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=8,
        concurrency_per_proc=(8, 8),
    )
)
data_judger_mapping = dict(math=["math_judger"])


================================================
FILE: oreal/datasets/__init__.py
================================================
# Copyright (c) InternLM. All rights reserved.
from .prompt import OrealPromptDataset, PromptCollator
from .trajectory import (
    InferDataset,
    TrajectoryCollator,
    TrajectoryDataset,
    TrajectoryDatasetWithFilter,
)

__all__ = [
    "OrealPromptDataset",
    "PromptCollator",
    "InferDataset",
    "TrajectoryDataset",
    "TrajectoryDatasetWithFilter",
    "TrajectoryCollator",
]


================================================
FILE: oreal/datasets/prompt.py
================================================
# Copyright (c) InternLM. All rights reserved.
import json

import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from xtuner._lite import get_logger

logger = get_logger()


def load_hf_datasets(repo, split="train"):
    dataset = load_dataset(repo, split=split)
    converted_ds = []
    for sample in dataset:
        converted_ds.append(
            {
                "pass_rate": sample["pass_rate"],
                "message_data": [{"role": "user", "content": sample["question"]}],
                "metadata": {
                    "data_source": "math",  # for the router to know which judger to use
                    "gold_answer": sample["gold_answer"],
                },
            }
        )
    logger.info(f"Loaded {len(converted_ds)} samples from {repo}")
    return converted_ds


def load_jsonl_datasets(file_path):
    with open(file_path, "r") as f:
        lines = f.readlines()
    datasets = []
    for line in lines:
        sample = json.loads(line)
        datasets.append(
            {
                "pass_rate": sample["pass_rate"],
                "message_data": [{"role": "user", "content": sample["question"]}],
                "metadata": {
                    "data_source": "math",  # for the router to know which judger to use
                    "gold_answer": sample["gold_answer"],
                },
            }
        )
    logger.info(f"Loaded {len(datasets)} samples from {file_path}")
    return datasets


def balance_difficulty_with_cfg(dataset, difficulty_balance_cfg):
    balanced_dataset = []
    for sample in dataset:
        pass_rate = sample["pass_rate"]
        for (low, high), repeat in difficulty_balance_cfg:
            if low <= pass_rate < high:
                balanced_dataset.extend([sample] * repeat)
                break
    logger.info(
        f"After difficulty balancing, the dataset size is {len(balanced_dataset)}"
    )
    return balanced_dataset


class OrealPromptDataset(Dataset):
    def __init__(self, path, tokenizer, difficulty_balance_cfg=None):
        if isinstance(path, str):
            path = [path]
        dataset = []
        for p in path:
            if p.endswith(".jsonl"):
                dataset.extend(load_jsonl_datasets(p))
            else:
                dataset.extend(load_hf_datasets(p))
        if difficulty_balance_cfg:
            dataset = balance_difficulty_with_cfg(dataset, difficulty_balance_cfg)
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        input_ids = self.tokenizer.apply_chat_template(
            sample["message_data"], add_generation_prompt=True
        )
        sample["input_ids"] = input_ids
        sample["labels"] = input_ids
        sample["num_tokens"] = len(input_ids)
        return sample


class PromptCollator:

    def __init__(self, pad_token_id=0, ignore_id=-100, pack_batch=False):
        self.pack_batch = pack_batch
        self.pad_token_id = pad_token_id
        self.ignore_id = ignore_id

    def __call__(self, instances):

        _instances = []
        for ins in instances:
            if isinstance(ins, list):
                _instances.extend(ins)
            else:
                _instances.append(ins)

        instances = _instances

        input_ids = []
        labels = []
        num_tokens = []
        metadatas = []
        message_datas = []

        for data in instances:

            input_ids.append(torch.LongTensor(data["input_ids"]))
            labels.append(torch.LongTensor(data["labels"]))
            metadatas.append(data["metadata"])
            message_datas.append(data["message_data"])

            if isinstance(data["num_tokens"], int):
                num_tokens.append(data["num_tokens"])
            else:
                num_tokens.extend(data["num_tokens"])

        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        num_tokens = torch.IntTensor(num_tokens)

        if len(instances) > 1 and self.pack_batch:
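            # pack mode: concatenate every sequence into a single row;
            # num_tokens keeps the per-sample lengths needed to split it back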

            input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
            labels = torch.cat(labels, dim=0).unsqueeze(0)
            attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)

        elif len(instances) > 1 and not self.pack_batch:

            input_ids = pad_sequence(
                input_ids, batch_first=True, padding_value=self.pad_token_id
            )
            labels = pad_sequence(
                labels, batch_first=True, padding_value=self.ignore_id
            )
            attention_mask = pad_sequence(
                attention_mask, batch_first=True, padding_value=0
            )
        else:
            input_ids = torch.stack(input_ids)
            labels = torch.stack(labels)
            attention_mask = torch.stack(attention_mask)

        if input_ids.shape != labels.shape:
            logger.error(f"[instances] {instances}")
            logger.error(f"[num_tokens] {num_tokens}")
            logger.error(f"[input_ids] {input_ids}")
            logger.error(f"[labels] {labels}")
            raise RuntimeError(
                "The shape of input_ids and labels must be "
                f"equal, but  found {input_ids.shape} and "
                f"{labels.shape}."
            )
        data_dict = {
            "input_ids": input_ids,
            "labels": labels,
            "num_tokens": num_tokens,
            "attention_mask": attention_mask.bool(),
            "metadata": metadatas,
            "message_data": message_datas,
        }

        return data_dict


if __name__ == "__main__":
    difficulty_balance_cfg = [
        # pass rate range, repeat times
        ((0.0, 0.2), 6),
        ((0.2, 0.4), 4),
        ((0.4, 0.6), 4),
        ((0.6, 0.8), 2),
    ]
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("internlm/OREAL-7B")
    dataset = OrealPromptDataset(
        "internlm/OREAL-RL-Prompts", tokenizer, difficulty_balance_cfg
    )
    print(len(dataset))
    print(dataset[0])
    print(tokenizer.decode(dataset[0]["input_ids"]))


================================================
FILE: oreal/datasets/trajectory.py
================================================
# Copyright (c) InternLM. All rights reserved.
import json
import random

import numpy as np
import torch
from xtuner._lite import get_logger
from xtuner._lite.algorithms.sft.dataset import SftCollator

logger = get_logger()


class InferDataset(torch.utils.data.Dataset):

    def __init__(self, prompts_input_ids, responses_ids, message_data, metadata):
        super().__init__()

        assert (
            len(prompts_input_ids)
            == len(responses_ids)
            == len(message_data)
            == len(metadata)
        ), f"The length of prompts_input_ids, responses_ids, message_data, metadata should be the same, but got {len(prompts_input_ids)}, {len(responses_ids)}, {len(message_data)}, {len(metadata)}"
        self.prompts_input_ids = prompts_input_ids
        self.responses_ids = responses_ids
        self.message_data = message_data
        self.metadata = metadata

    def __len__(self):
        return len(self.prompts_input_ids)

    def __getitem__(self, item):

        prompt_input_ids = self.prompts_input_ids[item]
        response_ids = self.responses_ids[item]
        num_prefill_tokens = len(prompt_input_ids)

        input_ids = prompt_input_ids + response_ids
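        # labels are pre-shifted: position i holds the target for token i + 1,
        # so the first num_prefill_tokens - 1 (prompt) positions are masked and
        # the final position, which has no next token, is ignored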
        labels = [-100] * (num_prefill_tokens - 1) + response_ids + [-100]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "num_tokens": len(input_ids),
            "message_data": self.message_data[item],
            "metadata": self.metadata[item],
        }


class TrajectoryDataset(torch.utils.data.Dataset):

    def __init__(self):
        super().__init__()
        self._num_action_tokens = 0
        self._num_total_tokens = 0
        self._trajectories = []

    @property
    def num_action_tokens(self):
        # after update() this is a numpy scalar; int() also covers the
        # initial plain-int value of 0 before the first update
        return int(self._num_action_tokens)

    @property
    def num_total_tokens(self):
        return self._num_total_tokens

    def update(self, trajectories):
        num_total_tokens = 0
        num_action_tokens = 0
        for data in trajectories:
            labels = np.array(data["labels"])
            num_total_tokens += labels.size
            num_action_tokens += (labels >= 0).sum()

        self._num_action_tokens = num_action_tokens
        self._num_total_tokens = num_total_tokens

        self._trajectories = trajectories

    def dump_jsonl(self, path, tokenizer, debug=False):

        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                json_line = {
                    "sequence": (
                        data["sequence_text"]
                        if "sequence_text" in data
                        else tokenizer.decode(data["input_ids"])
                    ),
                    "num_tokens": data["num_tokens"],
                }
                json_line["judger_reward"] = data["judger_reward"]
                json_line["judger_advantage"] = data["judger_advantage"]

                if debug:
                    json_line["input_ids"] = data["input_ids"]
                    json_line["labels"] = data["labels"]

                json_str = json.dumps(json_line, ensure_ascii=False)
                f.write(json_str + "\n")

    def dump_log(self, path, tokenizer, debug=False):

        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                log_string = f"[sequence]:\n{data['sequence_text'] if 'sequence_text' in data else tokenizer.decode(data['input_ids'])}\n\n"
                log_string += f"[num_tokens]: {data['num_tokens']}\n"
                log_string += f"[judger_reward]: {data['judger_reward']}\n"
                log_string += f"[judger_advantage]: {data['judger_advantage']}\n"
                f.write(log_string + "\n\n=======================\n")

    def __len__(self):
        return len(self._trajectories)

    def __getitem__(self, item):

        return self._trajectories[item]


class TrajectoryDatasetWithFilter(TrajectoryDataset):
    def __init__(self, repeat_k=1, only_keep_1_pair=True):
        super().__init__()
        self.repeat_k = repeat_k
        self.only_keep_1_pair = only_keep_1_pair

    def update(self, trajectories):
        # split trajectories into k groups: (a, a, b, b, c, c) -> [(a, a), (b, b), (c, c)]
        groups = [
            trajectories[i : i + self.repeat_k]
            for i in range(0, len(trajectories), self.repeat_k)
        ]
        kept_trajectories = []
        for group in groups:
            correctness = [1 if data["judger_reward"] == 1 else 0 for data in group]
            correct = [data for data in group if data["judger_reward"] == 1]
            incorrect = [data for data in group if data["judger_reward"] != 1]
            pass_rate = sum(correctness) / len(correctness)
            # drop groups whose responses are all correct or all incorrect
            if pass_rate == 1 or pass_rate == 0:
                continue
            if self.only_keep_1_pair:
                # keep at most 1 correct and 1 incorrect trajectory per group
                correct = random.choice(correct)
                incorrect = random.choice(incorrect)
                correct["pass_rate"] = pass_rate
                incorrect["pass_rate"] = pass_rate
                kept_trajectories.append(correct)
                kept_trajectories.append(incorrect)
            else:
                for data in group:
                    data["pass_rate"] = pass_rate
                    kept_trajectories.append(data)

        super().update(kept_trajectories)

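# A minimal sketch (not part of the original file) of the filtering above,
# assuming repeat_k=2 and judger_reward 1/-1 for correct/incorrect rollouts:
#   trajectories = [a1, a2, b1, b2]  ->  groups = [(a1, a2), (b1, b2)]
#   a group with pass_rate 0.0 or 1.0 (all wrong or all right) is dropped;
#   with only_keep_1_pair=True, a mixed group contributes exactly one
#   correct/incorrect pair, each tagged with the group's pass_rate.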

class TrajectoryCollator(SftCollator):

    def __call__(self, instances):

        data = super().__call__(instances)
        data["judger_rewards"] = [item["judger_reward"] for item in instances]
        data["judger_advantages"] = [item["judger_advantage"] for item in instances]
        if "pass_rate" in instances[0]:
            data["pass_rate"] = [item["pass_rate"] for item in instances]
        return data


================================================
FILE: oreal/judgers/__init__.py
================================================
# Copyright (c) InternLM. All rights reserved.
from .base_judger import (
    BaseJudger,
    register_judger,
    registered_judgers,
)
from .math_judger import MathJudger
from .router import InputData, ParallelRouter

__all__ = [
    "register_judger",
    "registered_judgers",
    "BaseJudger",
    "MathJudger",
    "InputData",
    "ParallelRouter",
]


================================================
FILE: oreal/judgers/base_judger.py
================================================
# Copyright (c) InternLM. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
    Dict,
    Generic,
    List,
    Optional,
    Type,
    TypedDict,
    TypeVar,
    Union,
)

T = TypeVar("T")
MessageItem = TypedDict("MessageItem", {"role": str, "content": str})
Reward = Union[float, List[float], None]
MetaData = TypedDict("MetaData", {"data_source": str})


@dataclass
class JudgeStatus(Generic[T]):
    ok: bool = True
    reason: Optional[str] = None
    handle: Optional[T] = None


class BaseJudger(ABC):
    def __init__(self):
        pass

    @abstractmethod
    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        raise NotImplementedError()

    @abstractmethod
    def on_reward_required(
        self,
        status: JudgeStatus,
        timeout: Optional[float] = None,
    ) -> Reward:
        raise NotImplementedError()


registered_judgers: Dict[str, Type[BaseJudger]] = {}


def register_judger(name: str):
    global registered_judgers

    def wrapper(cls):
        assert name not in registered_judgers, f"{name} already in {registered_judgers}"
        registered_judgers[name] = cls
        return cls

    return wrapper

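# A minimal sketch (not part of the original file) of registering a custom
# judger through the decorator above; `EchoJudger` is a hypothetical name:
#
#   @register_judger("echo_judger")
#   class EchoJudger(BaseJudger):
#       def on_data_received(self, prompt_messages, completion_messages,
#                            metadata) -> JudgeStatus:
#           # stash anything needed to compute the reward later
#           return JudgeStatus(ok=True, handle=completion_messages[-1])
#
#       def on_reward_required(self, status, timeout=None) -> Reward:
#           return 1.0 if status.handle else None
#
# registered_judgers["echo_judger"] then becomes available to the router.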

================================================
FILE: oreal/judgers/math_judger.py
================================================
# Copyright (c) InternLM. All rights reserved.
import random
import time
from typing import List, Optional, Tuple

import requests

from .base_judger import BaseJudger, JudgeStatus, MessageItem, Reward, register_judger
from .utils import extract_answer, math_equal


@register_judger("math_judger")
class MathJudger(BaseJudger):
    verify_prompt = """You are a helpful assistant who evaluates the correctness and quality of models' outputs.
    As a grading expert, please judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.

    Here are some evaluation criteria:
    1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct.
    2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question.
    3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct.
    4. Some answers may be expressed in different ways; for example, one answer may be a mathematical expression while another is a textual description. As long as the meaning expressed is the same, the answer is considered correct. Likewise, some formulas are written in different ways but are equivalent, and therefore correct.
    5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer.

    Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
    A: CORRECT
    B: INCORRECT
    Just return the letters \"A\" or \"B\", with no text around it.

    Here is your task. Simply reply with either \"A\" or \"B\" as instructed above. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.


    <Original Question Begin>:
    {question}
    <Original Question End>


    <Gold Target Begin>:
    {gold_answer}
    <Gold Target End>


    <Predicted Answer Begin>:
    {answer}
    <Predicted Answer End>


    Judging the correctness of candidates' answers:"""

    def __init__(
        self,
        hosts: List[str],
        max_retries: int = 1,
        retry_delay: float = 1.0,
        stop_word="<|im_end|>",
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
    ):
        super().__init__()
        self.hosts = hosts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.stop_word = stop_word
        self.thinking_finish_words = thinking_finish_words

        self.host_ip_idx = random.randint(0, len(hosts) - 1)
        self.model_name = requests.get(
            f"http://{self.hosts[self.host_ip_idx]}/v1/models",
            headers={"Authorization": "Bearer "},
        ).json()["data"][0]["id"]

    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        question = prompt_messages[-1]["content"]
        response = completion_messages[-1]["content"]
        question_type = metadata.get("question_type", None)
        gold_answer = metadata["gold_answer"]
        if not response.strip().endswith(self.stop_word):
            # If the response does not end with the stop word, it is not a complete response, treat as incorrect
            return JudgeStatus(
                ok=True,
                handle={
                    "question": question,
                    "question_type": question_type,
                    "response": response,
                    "gold_answer": gold_answer,
                    "verify_label": False,
                },
            )

        for thinking_finish_word in self.thinking_finish_words:
            if thinking_finish_word in response:
                response = response.split(thinking_finish_word)[-1]

        response = response.replace(self.stop_word, "")

        # first try to extract and verify with rule, if correct, return
        extracted_answer, verify_label = self._extract_and_verify_with_logic(
            response, gold_answer
        )
        if verify_label is True:
            return JudgeStatus(
                ok=True,
                handle={
                    "question": question,
                    "question_type": question_type,
                    "response": response,
                    "gold_answer": gold_answer,
                    "verify_label": verify_label,
                },
            )

        # then try to evaluate with model
        res_string, verify_label = self._evaluate_answer_with_llm(
            question, question_type, response, gold_answer
        )
        return JudgeStatus(
            ok=True,
            handle={
                "question": question,
                "question_type": question_type,
                "response": response,
                "gold_answer": gold_answer,
                "verify_label": verify_label,
            },
        )

    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        if status.handle is None:
            return None
        if status.handle["verify_label"] is not None:
            return 1.0 if status.handle["verify_label"] else -1.0
        return None

    def _evaluate_answer_with_llm(
        self, question: str, question_type: str, answer: str, gold_answer: str
    ) -> Tuple[Optional[str], Optional[bool]]:
        for i in range(self.max_retries):
            host = self.hosts[self.host_ip_idx]
            self.host_ip_idx = (self.host_ip_idx + 1) % len(self.hosts)
            prompt = self.verify_prompt.format(
                question=question, answer=answer, gold_answer=gold_answer
            )
            try:
                res = requests.post(
                    f"http://{host}/v1/chat/completions",
                    json={
                        "model": self.model_name,
                        "messages": [
                            {
                                "role": "user",
                                "content": prompt,
                            }
                        ],
                        "temperature": 0.0,
                        "top_p": 0.8,
                        "top_k": 20,
                        "repetition_penalty": 1.05,
                        "max_tokens": 100,
                        "stop": ["<|im_end|>", "<|endoftext|>"],
                    },
                )
                res_string = res.json()["choices"][0]["message"]["content"]
                print(f"Evaluate result: {res_string}")
                verify_label = self._verify_from_string(res_string)
                if verify_label is None:
                    raise ValueError(
                        f"Evaluate result is None, judger prediction: {res_string}"
                    )
                return res_string, verify_label

            except Exception as e:
                print(f"Error verifying answer: {e}")
                time.sleep(self.retry_delay)
                continue
        print(f"Failed to verify answer after {self.max_retries} retries.")
        return None, None

    def _verify_from_string(self, verification: str):
        if "A" in verification and "B" not in verification:
            label = True
        elif "B" in verification and "A" not in verification:
            label = False
        else:  # judger model failed to predict A or B
            label = None
        return label

    def _extract_and_verify_with_logic(
        self, response: str, gold_answer: str
    ) -> Tuple[str, bool]:
        extracted_answer = extract_answer(response)
        verify_label = math_equal(extracted_answer, gold_answer)
        return extracted_answer, verify_label

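# A minimal usage sketch (not part of the original file), assuming an
# OpenAI-compatible judge model is served at the hypothetical address
# 127.0.0.1:8000 (the constructor queries /v1/models at build time):
#
#   judger = MathJudger(hosts=["127.0.0.1:8000"])
#   status = judger.on_data_received(
#       prompt_messages=[{"role": "user", "content": "What is 2 + 2?"}],
#       completion_messages=[
#           {"role": "assistant", "content": "\\boxed{4}<|im_end|>"}
#       ],
#       metadata={"gold_answer": "4"},
#   )
#   reward = judger.on_reward_required(status)  # 1.0 correct, -1.0 incorrect
#
# Rule-based checking (extract_answer + math_equal) is tried first; the LLM
# judge is only consulted when the rule-based verification fails.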

================================================
FILE: oreal/judgers/router.py
================================================
# Copyright (c) InternLM. All rights reserved.
import atexit
import functools
import os
import queue
import time
import traceback
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Event, Process, Queue, connection
from multiprocessing.synchronize import Event as EventClass
from typing import (
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypedDict,
    TypeVar,
    cast,
)
from uuid import uuid4

import loguru
from typing_extensions import NotRequired

from .base_judger import (
    JudgeStatus,
    MessageItem,
    MetaData,
    Reward,
    registered_judgers,
)


class InputData(TypedDict):
    prompt_messages: List[MessageItem]
    completion_messages: List[MessageItem]
    metadata: NotRequired[MetaData]


T = TypeVar("T")


@dataclass
class GenericTask(Generic[T]):
    token: str
    index: int
    judger: str
    content: T


@dataclass
class SubprocessConfig:
    loguru_handlers: Optional[List[dict]] = None
    worker_init_func: Optional[Callable] = None


class ParallelRouter:
    def __init__(
        self,
        judgers_config: Dict[str, dict],
        data_judger_mapping: Dict[str, Optional[List[str]]],
        logger: Optional["loguru.Logger"] = None,
        subprocess_config: Optional[SubprocessConfig] = None,
    ):
        if logger is not None:
            self.logger = logger
        else:
            from unittest import mock

            self.logger = mock.Mock()

        if subprocess_config is not None:
            self.subprocess_config = subprocess_config
        else:
            self.subprocess_config = SubprocessConfig()

        if not (
            isinstance(judgers_config, dict)
            and all(
                isinstance(k, str) and isinstance(v, dict)
                for k, v in judgers_config.items()
            )
        ):
            raise TypeError(
                f"Illegal judgers_config: {judgers_config}\n"
                "Should be Dict[str, dict]"
            )
        if "RM" in judgers_config.keys():
            raise KeyError(
                f"'RM' is a reserved judger keywork for {self.__class__.__name__}, "
                f"please remove it from judgers_config: {judgers_config}"
            )
        self.judgers_config = judgers_config

        data_judger_mapping: Dict[str, List[str]] = {
            k: v or [] for k, v in data_judger_mapping.items()
        }  # change None to empty list []
        if not (
            isinstance(data_judger_mapping, dict)
            and all(
                isinstance(k, str)
                and isinstance(v, (list, tuple, set))
                and all(isinstance(vv, str) for vv in v)
                for k, v in data_judger_mapping.items()
            )
        ):
            raise TypeError(
                f"Illegal data_judger_mapping: {data_judger_mapping}\n"
                "Should be Dict[str, List[str]]"
            )
        self.data_judger_mapping = data_judger_mapping

        avail_judgers = set(self.judgers_config.keys()) | {"RM"}
        _used_judgers: List[str] = []
        for v in data_judger_mapping.values():
            _used_judgers.extend(v)
        used_judgers: set = set(_used_judgers)
        if unused := avail_judgers - used_judgers:
            self.logger.warning(
                "Following judgers are available but not "
                f"used in data mapping: {unused}\n"
                "Please make sure this is intended"
            )
            # remove unused configs
            for judger_name in unused:
                self.judgers_config.pop(judger_name, None)
        if missing := used_judgers - avail_judgers:
            self.logger.warning(
                "Following judgers are configured to be used "
                f"but not built in data mapping: {missing}\n"
                "Please make sure this is intended"
            )
            # remove missing judgers from mapping, to prevent potential errors
            for source in list(self.data_judger_mapping.keys()):
                before = set(self.data_judger_mapping[source])
                self.data_judger_mapping[source] = list(before - missing)
            # then filter out data_mapping without available judgers
            self.data_judger_mapping = {
                source: judgers
                for source, judgers in self.data_judger_mapping.items()
                if len(judgers) > 0
            }

        # Try building judgers in __init__ so that exceptions are raised early
        for judger_name, judger_conf in self.judgers_config.items():
            _ = self._build_judger(judger_name, judger_conf)

        self._processes: List[Process] = []
        self._stop_event = Event()
        atexit.register(self.shutdown)

        self._input_queues: Dict[str, Queue[GenericTask[InputData]]] = {
            judger_name: Queue() for judger_name in self.judgers_config.keys()
        }
        self._output_queue: Queue[GenericTask[Reward]] = Queue()
        self._exc_queue: Queue[Tuple[str, Exception]] = Queue()
        self._num_tasks: Dict[str, int] = {}  # for each token
        self._num_indexes: Dict[str, int] = {}  # for each token
        self._results_buffer: Dict[str, List[GenericTask[Reward]]] = defaultdict(
            list
        )  # results buffer grouped by the key "token"

    def submit(self, data_batch: List[InputData]):
        indexes_for_ext: List[int] = []
        indexes_for_local: List[int] = []
        tasks_input: List[GenericTask[InputData]] = []
        token = str(uuid4())
        for index, data_item in enumerate(data_batch):
            if (
                not isinstance(data_item, dict)
                or "metadata" not in data_item
                or "prompt_messages" not in data_item
                or "completion_messages" not in data_item
            ):
                indexes_for_local.append(index)
                continue
            source = data_item["metadata"].get("data_source", None)
            if source is None or source not in self.data_judger_mapping:
                indexes_for_local.append(index)
                continue
            indexes_for_ext.append(index)
            for judger in self.data_judger_mapping[source]:
                if judger == "RM":
                    indexes_for_local.append(index)
                else:
                    tasks_input.append(
                        GenericTask(
                            token=token,
                            index=index,
                            judger=judger,
                            content=data_item,
                        )
                    )

        self._num_tasks[token] = len(tasks_input)
        self._num_indexes[token] = len(data_batch)
        for task in tasks_input:
            self._input_queues[task.judger].put(task, block=True, timeout=1)

        if not self._processes:
            self.logger.debug("Starting processes...")
            for judger_name, judger_conf in self.judgers_config.items():
                num_proc = judger_conf.pop("num_processes", 1)
                self._processes.extend(
                    [
                        Process(
                            target=ParallelRouter._safe_process_worker,
                            kwargs={
                                "stop_event": self._stop_event,
                                "judger_name": judger_name,
                                "judger_conf": judger_conf,
                                "input_queue": self._input_queues[judger_name],
                                "output_queue": self._output_queue,
                                "exc_queue": self._exc_queue,
                                "config": self.subprocess_config,
                            },
                            daemon=True,
                        )
                        for _ in range(num_proc)
                    ]
                )
            for p in self._processes:
                p.start()
            self.logger.debug(f"Start processes done, total {len(self._processes)}")

        return token, indexes_for_local

    def query(
        self, token: str, timeout: float = 0
    ) -> Optional[List[Optional[Dict[str, Reward]]]]:
        start = time.time()
        while True:
            self._try_catch_subprocess_exceptions()
            try:
                result = self._output_queue.get(timeout=0.1)
                self._results_buffer[result.token].append(result)
            except queue.Empty:
                pass
            if len(self._results_buffer[token]) == self._num_tasks[token]:
                results = self._results_buffer.pop(token)
                num_tasks = self._num_tasks.pop(token)
                num_indexes = self._num_indexes.pop(token)
                rewards: List[Dict[str, Reward]] = [{} for _ in range(num_indexes)]
                for result in results:
                    reward = result.content
                    if result.judger in rewards[result.index]:
                        self.logger.warning(
                            f"{result.judger} already exists: {rewards[result.index]}, "
                            f"will replace --> {reward}"
                        )
                    rewards[result.index][result.judger] = reward
                # convert empty dicts to None
                return [r or None for r in rewards]
            if timeout > 0 and (time.time() - start) > timeout:
                raise TimeoutError(
                    f"Timeout after {timeout} seconds, got {len(self._results_buffer[token])} results, expected {self._num_tasks[token]}"
                )

    @staticmethod
    def _safe_process_worker(
        stop_event: EventClass,
        judger_name: str,
        judger_conf: dict,
        input_queue: "Queue[GenericTask[InputData]]",
        output_queue: "Queue[GenericTask[Reward]]",
        exc_queue: "Queue[Tuple[str, Exception]]",
        config: SubprocessConfig,
    ):
        try:
            ParallelRouter._process_worker(
                stop_event=stop_event,
                judger_name=judger_name,
                judger_conf=judger_conf,
                input_queue=input_queue,
                output_queue=output_queue,
                exc_queue=exc_queue,
                config=config,
            )
        except Exception as e:
            exc_queue.put((judger_name, e), timeout=1)

    @staticmethod
    def _process_worker(
        stop_event: EventClass,
        judger_name: str,
        judger_conf: dict,
        input_queue: "Queue[GenericTask[InputData]]",
        output_queue: "Queue[GenericTask[Reward]]",
        exc_queue: "Queue[Tuple[str, Exception]]",
        config: SubprocessConfig,
    ):
        from xtuner._lite import get_logger

        logger = get_logger()
        if config.loguru_handlers is not None:
            for handler in config.loguru_handlers:
                handler["enqueue"] = True
                logger.add(*handler)
        if config.worker_init_func is not None:
            config.worker_init_func()

        # Infer num threads for each stage according to configs
        _num_threads = judger_conf.pop("concurrency_per_proc", (1, 1))
        if isinstance(_num_threads, (tuple, list)) and len(_num_threads) == 2:
            num_threads_s1, num_threads_s2 = _num_threads
        elif isinstance(_num_threads, int):
            num_threads_s1 = max(1, _num_threads // 2)
            num_threads_s2 = max(1, _num_threads - num_threads_s1)
        else:
            raise TypeError(
                "`concurrency_per_proc` in judger_conf should be int or "
                f"Tuple[int, int], got {type(_num_threads)}: {_num_threads}"
            )

        # Lazy build judgers in subprocesses to avoid serialization errors
        judger = ParallelRouter._build_judger(judger_name, judger_conf)
        handle_queue: queue.Queue[GenericTask[JudgeStatus]] = queue.Queue()
        log_prefix = f"[pid={os.getpid()},{judger_name}]"

        def report_exc_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    stack_trace = traceback.format_exc()
                    logger.error(
                        f"{log_prefix} "
                        f"Thread worker of {judger_name} raised "
                        f"{type(e).__name__}: {e}\n"
                        f"Stack trace: {stack_trace}"
                    )
                    exc_queue.put((judger_name, e), timeout=1)

            return wrapper

        # Stage 1: input_queue -> judger.on_data_received -> handle_queue
        @report_exc_wrapper
        def thread_worker_s1():
            while not stop_event.is_set():
                try:
                    task = input_queue.get(timeout=0.1)
                    logger.debug(f"{log_prefix} dequeue input: {task}")
                except queue.Empty:
                    logger.debug(f"{log_prefix} input queue empty")
                    time.sleep(0.1)
                    continue
                data = task.content
                if "metadata" not in data:
                    raise RuntimeError(
                        f"'metadata' not in data.keys(): {list(data.keys())}"
                    )
                logger.debug(f"{log_prefix} on_data_received")
                handle = judger.on_data_received(
                    data["prompt_messages"],
                    data["completion_messages"],
                    cast(dict, data["metadata"]),
                )
                logger.debug(f"{log_prefix} got handle")
                new_task = GenericTask(
                    token=task.token,
                    index=task.index,
                    judger=task.judger,
                    content=handle,
                )
                while True:
                    try:
                        handle_queue.put(
                            new_task,
                            timeout=0.1,
                        )
                        logger.debug(f"{log_prefix} enqueue handle: {new_task}")
                        break
                    except queue.Full:
                        time.sleep(0.1)

        # Stage 2: handle_queue -> judger.on_reward_required -> output_queue
        @report_exc_wrapper
        def thread_worker_s2():
            while not stop_event.is_set():
                try:
                    task = handle_queue.get(timeout=0.1)
                    logger.debug(f"{log_prefix} dequeue handle: {task}")
                except queue.Empty:
                    logger.debug(f"{log_prefix} handle queue empty")
                    time.sleep(0.1)
                    continue
                logger.debug(f"{log_prefix} on_reward_required")
                reward = judger.on_reward_required(task.content)
                logger.info(f"{log_prefix} got result")
                new_task = GenericTask(
                    token=task.token,
                    index=task.index,
                    judger=task.judger,
                    content=reward,
                )
                while True:
                    try:
                        output_queue.put(
                            new_task,
                            timeout=0.1,
                        )
                        logger.debug(f"{log_prefix} enqueue output: {new_task}")
                        break
                    except queue.Full:
                        time.sleep(0.1)

        from threading import Thread

        threads: List[Thread] = []
        for _ in range(num_threads_s1):
            threads.append(Thread(target=thread_worker_s1, daemon=True))
        for _ in range(num_threads_s2):
            threads.append(Thread(target=thread_worker_s2, daemon=True))
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    @staticmethod
    def _build_judger(judger_name: str, judger_conf: dict):
        judger_conf = deepcopy(judger_conf)
        judger_conf.pop("num_processes", None)
        judger_conf.pop("concurrency_per_proc", None)
        _type = judger_conf.pop("type", None)
        if _type is None:
            _type = judger_name
        if _type not in registered_judgers:
            raise KeyError(
                f"{judger_name} use unregistered judger type: {_type}. "
                f"Available judgers are: {list(registered_judgers.keys())}"
            )
        cls = registered_judgers[_type]
        return cls(**judger_conf)

    def _try_catch_subprocess_exceptions(self):
        exc_handles: List[Tuple[str, Exception]] = []
        while True:
            try:
                exc_handle = self._exc_queue.get(timeout=0.001)
                exc_handles.append(exc_handle)
            except queue.Empty:
                break
        if exc_handles:
            error_message = "\n".join(
                [
                    f"- [{judger_name}] {type(exc).__name__}: {exc}"
                    for judger_name, exc in exc_handles
                ]
            )
            raise RuntimeError(
                "Following threads/processes raise exceptions unexpectedly:\n"
                f"{error_message}\n"
                "Program terminated"
            )

    def shutdown(self, timeout: float = 2.0):
        if not hasattr(self, "_processes") or not self._processes:
            return
        if not self._stop_event.is_set():
            self._stop_event.set()
        connection.wait([p.sentinel for p in self._processes], timeout=timeout)
        for p in self._processes:
            if p.is_alive():
                p.kill()
                p.join()
        self._processes = []

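# A minimal usage sketch (not part of the original file); the host address
# and the "math" data source name are hypothetical:
#
#   router = ParallelRouter(
#       judgers_config={
#           "math_judger": {"hosts": ["127.0.0.1:8000"], "num_processes": 2},
#       },
#       data_judger_mapping={"math": ["math_judger"]},
#   )
#   token, local_indexes = router.submit([
#       {
#           "prompt_messages": [{"role": "user", "content": "What is 2 + 2?"}],
#           "completion_messages": [{"role": "assistant", "content": "4"}],
#           "metadata": {"data_source": "math", "gold_answer": "4"},
#       },
#   ])
#   rewards = router.query(token, timeout=60)  # e.g. [{"math_judger": 1.0}]
#   router.shutdown()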

================================================
FILE: oreal/judgers/utils.py
================================================
# flake8: noqa
# isort: skip_file

import multiprocessing
import re
from math import isclose
from typing import Optional, Union
from collections import defaultdict, Counter

from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr


def extract_answer(pred_str: str, execute: bool = False) -> str:
    if re.search("\\boxed|boxed|\\box|box", pred_str):
        answer = re.split("\\boxed|boxed|\\box|box", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = answer.split("$")[0].strip()
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    choice = re.findall(r"([A-E]):\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option
    choice = re.findall(r"\(([A-E])\)\s*(.*)", a)
    if len(choice) > 0:
        for option, content in choice:
            a = option

    a = re.split(r"=|\\approx|≈", a)[-1]

    # multiple lines
    answer = ""
    preds = re.split("\n", a)
    for pred in preds:
        if "\\begin{align" in pred or pred.endswith(":"):
            continue
        if pred != "" and pred[0] == ":":
            pred = pred[1:]
        if pred != "" and pred[-1] == ".":
            pred = pred[:-1]
        if pred != "" and pred[-1] == "/":
            pred = pred[:-1]
        pred = strip_string(pred)
        pred = re.sub(r"^[a-zA-Z0-9]+[\)]\s*", "", pred)
        for p in pred.split("{}"):
            if p != "":
                pred = p
                break

        pred = re.sub(r"^\{([A-Z])\}|\(([A-Z])\)", r"\1\2", pred)
        if pred != "":
            answer = pred
            break
    return answer


def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if len(substr) > 0 and substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except Exception:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        if "sqrt" not in a:
            a = int(a)
        if "sqrt" not in b:
            b = int(b)
        assert string == f"{a}/{b}"
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except Exception:
        return string


def _fix_sqrt(string):
    _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
    return _string


def strip_string(string):
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")
    string = string.replace("\\ ", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove unit: miles, dollars if after is not none
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
        string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("\\text", "")
    string = string.replace("x\\in", "")

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # cdot
    string = string.replace("\\cdot", "")

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\inity", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quote
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

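# Illustrative normalizations (not part of the original file) performed by
# strip_string above; outputs follow from the substitution rules:
#   strip_string("\\frac12")  -> "\\frac{1}{2}"  (_fix_fracs)
#   strip_string("\\sqrt3")   -> "\\sqrt{3}"     (_fix_sqrt)
#   strip_string("k = 5")     -> "5"             (short "lhs =" prefix dropped)
#   strip_string("2.000")     -> "2"             (trailing zeros removed)
#   strip_string("50\\%")     -> "50"            (percent sign removed)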

def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx : right_brace_idx + 1]

    return retval


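# NOTE: the `extract_answer` below shadows the function of the same name
# defined earlier in this file; at import time only this second definition
# takes effect.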
def extract_answer(pred_str: str, execute: bool = False) -> str:
    if re.search("\boxed|boxed", pred_str):
        answer = re.split("\boxed|boxed", pred_str)[-1]
        if len(answer) == 0:
            return ""
        elif answer[0] == "{":
            stack = 1
            a = ""
            for c in answer[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = answer.split("$")[0].strip()
    elif re.search("[Tt]he (final )?answer is:?", pred_str):
        a = re.split("[Tt]he (final )?answer is:?", pred_str)[-1].strip().rstrip(".")
    elif pred_str.startswith("```python") and execute:
        # fall back to program
        from lagent import get_tool

        a = get_tool("IPythonInteractive").exec(pred_str).value or ""
    else:  # use the last number
        pred = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if len(pred) >= 1:
            a = pred[-1]
        else:
            a = ""
    # multiple lines
    pred = a.split("\n")[0]
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred


def is_digit(s):
    try:
        float(str(s).replace(",", ""))
        return True
    except ValueError:
        return False


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    tolerance: float = 1e-4,
    timeout: bool = False,
) -> bool:
    """Exact match of math if and only if:

    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=tolerance):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        (prediction.startswith("[") and prediction.endswith("]"))
        and (reference.startswith("[") and reference.endswith("]"))
        or (prediction.startswith("(") and prediction.endswith(")"))
        and (reference.startswith("(") and reference.endswith(")"))
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                [
                    math_equal(
                        pred_parts[i], ref_parts[i], include_percentage, is_close
                    )
                    for i in range(len(pred_parts))
                ]
            ):
                return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False

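# A few illustrative checks (not part of the original file) of the matching
# rules above, assuming sympy can parse the expressions involved:
#   math_equal("0.5", "1/2")       -> True  (symbolic: simplify(a - b) == 0)
#   math_equal("50", "0.5")        -> True  (include_percentage tries ref * 100)
#   math_equal("(1, 2)", "(1, 2)") -> True  (string/element-wise compare)
#   math_equal("x + 1", "1 + x")   -> True  (sympy symbolic equivalence)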

def math_equal_process(param):
    return math_equal(param[-2], param[-1])


def math_equal_process_v2(param):
    if param[-2] is None:
        return False
    return math_equal(param[-2], param[-1])


def symbolic_equal(a, b):

    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), rel_tol=1e-3):
            return True
    except Exception:
        pass
    return False


def symbolic_equal_process(a, b, output_queue):
    result = symbolic_equal(a, b)
    output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    return output_queue.get()


def math_majority_vote(answers: list, majority: Optional[int] = None):
    # threshold = len(answers) // 2 + 1
    ans2cnt, ans2idx = Counter(), defaultdict(list)
    for i, ans in enumerate(answers):
        if isinstance(ans, str) and ans.strip():
            for key in ans2cnt.keys():
                if math_equal(ans, key):
                    ans2cnt[key] += 1
                    ans2idx[key].append(i)
                    break
            else:
                ans2cnt[ans] += 1
                ans2idx[ans].append(i)
    if ans2cnt:
        maj, cnt = ans2cnt.most_common(1)[0]
        if maj and cnt >= (majority or 1):
            return maj, ans2idx[maj]
    return None, []

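# Illustration (not part of the original file): voting groups answers that
# are math_equal, returning the majority answer and its indices:
#   math_majority_vote(["1/2", "0.5", "2"])    -> ("1/2", [0, 1])
#   math_majority_vote(["1", "2"], majority=2) -> (None, [])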

================================================
FILE: oreal/utils.py
================================================
import importlib.util
import os
import types


class ConfigDict(dict):

    def __getattr__(self, item):
        if item in self:
            return self[item]
        raise AttributeError(f"'ConfigDict' object has no attribute '{item}'")

    def __setattr__(self, key, value):
        self[key] = value


class Config:

    @staticmethod
    def fromfile(file_path):
        config_dict = ConfigDict()
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Config file not found: {file_path}")

        # Load the configuration file as a module
        spec = importlib.util.spec_from_file_location("config_module", file_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        # Function to convert nested dictionaries to ConfigDict recursively
        def convert_to_config_dict(d):
            if isinstance(d, dict):

                config_dict = ConfigDict()
                for key, value in d.items():
                    if isinstance(value, dict):
                        config_dict[key] = convert_to_config_dict(value)
                    else:
                        config_dict[key] = value
                return config_dict
            else:
                return d

        # Retrieve all attributes (variables) from the module
        for attribute_name in dir(config_module):
            if not attribute_name.startswith("__"):
                config_dict[attribute_name] = convert_to_config_dict(
                    getattr(config_module, attribute_name)
                )
        for key, value in list(config_dict.items()):
            if isinstance(value, (types.FunctionType, types.ModuleType)):
                config_dict.pop(key)
        return config_dict

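# A minimal usage sketch (not part of the original file); whether a given
# config defines `seed` or `work_dir` depends on the file loaded:
#
#   cfg = Config.fromfile(
#       "oreal/configs/oreal_wo_tokenrm_OREAL-7B-SFT_seqlen16k.py"
#   )
#   print(cfg.seed)                  # attribute access via ConfigDict
#   cfg.work_dir = "work_dirs/demo"  # attribute assignment writes the dict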

================================================
FILE: requirements.text
================================================
fire
flash-attn
torch>=2.5.0
xtuner[all]==0.2.0rc0


================================================
FILE: train_oreal.py
================================================
# Copyright (c) InternLM. All rights reserved.
import json
import os
import sys
import time
from collections import OrderedDict
from contextlib import nullcontext
from datetime import datetime, timedelta

import fire
import torch
import torch.distributed as dist
from mmengine import mkdir_or_exist
from mmengine.runner import set_random_seed
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.import_utils import is_flash_attn_2_available
from xtuner._lite import get_device, get_logger, get_torch_device_module
from xtuner._lite.accelerate import profile_time_and_memory, unpack_sequence
from xtuner._lite.algorithms.sft import SftCollator
from xtuner._lite.modelings import register_remote_code
from xtuner._lite.parallel import (
    ParallelSampler,
    setup_parallel,
    split_for_sequence_parallel,
)
from xtuner._lite.patches import AutoPatch, FSDPConfig
from xtuner._lite.patches.utils import pad_to_max_length, pad_to_multiple_of

from oreal.datasets import (
    InferDataset,
    OrealPromptDataset,
    PromptCollator,
    TrajectoryCollator,
    TrajectoryDataset,
    TrajectoryDatasetWithFilter,
)
from oreal.judgers import ParallelRouter
from oreal.utils import Config

logger = get_logger()

DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()


torch._dynamo.config.cache_size_limit = 16384


class RLParallelSampler(ParallelSampler):
    def __iter__(self):
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (indices * int(self.total_size / len(indices) + 1))[
                : self.total_size
            ]

        # subsample
        chunk_size = len(indices) // self.world_size
        start = self.rank * chunk_size
        end = start + chunk_size
        indices = indices[start:end]

        return iter(indices[self.step :])

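# Note (not part of the original file): compared with the parent
# ParallelSampler, this __iter__ resumes from `self.step`, skipping indices
# already consumed so that mid-epoch restarts do not replay old prompts.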

def log_format(rank, debug=False):

    formatter = f"[XTuner][RANK {rank}]"
    formatter += "[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]"

    if debug:
        formatter += "[<cyan>{name}</cyan>:"
        formatter += "<cyan>{function}</cyan>:"
        formatter += "<cyan>{line}</cyan>]"

    formatter += " <level>{message}</level>"
    return formatter


def is_interval(step, total_steps, interval):
    return (step + 1) % interval == 0 or (step + 1) == total_steps


def reduce_mean(data, group):
    data_tensor = torch.tensor(data, device=DEVICE)
    dist.all_reduce(data_tensor, op=dist.ReduceOp.AVG, group=group)
    return data_tensor.item()


def threshold_rescale(prob, threshold=0.5):
    prob = prob - threshold
    prob = prob / (1 - threshold)
    prob = prob.clamp(0, 1)
    return prob


def topk_rescale(prob, topk_ratio=0.5):
    topk_num = int(prob.numel() * topk_ratio)
    values, indices = torch.topk(prob, topk_num)
    result = torch.zeros_like(prob)
    if values.max() != values.min():
        normalized_values = (values - values.min()) / (values.max() - values.min())
    else:
        normalized_values = torch.ones_like(values)
    result[indices] = normalized_values
    return result

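# A quick illustration (not part of the original file) of the two rescaling
# schemes above on a toy probability vector:
#   prob = torch.tensor([0.2, 0.6, 0.9, 0.4])
#   threshold_rescale(prob, threshold=0.5)  # -> [0.0, 0.2, 0.8, 0.0]
#   topk_rescale(prob, topk_ratio=0.5)      # -> [0.0, 0.0, 1.0, 0.0]
#                                           # top-2 values min-max normalized,
#                                           # the rest zeroed out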

def train_oreal(cfg_path, **kwargs):
    args = Config.fromfile(cfg_path)
    args.update(kwargs)

    ###########################################################################
    #                           1. Environment                                #
    ###########################################################################
    register_remote_code()

    setup_parallel()
    set_random_seed(args.seed)

    rank = dist.get_rank()

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    objects = [timestamp]
    dist.broadcast_object_list(objects, src=0)
    timestamp = objects[0]

    args.work_dir = os.path.join(args.work_dir, timestamp)
    mkdir_or_exist(args.work_dir)

    log_file = os.path.join(args.work_dir, f"rank{rank}.log")

    # Change the log format printed in the terminal
    lvl = "DEBUG" if args.debug else "INFO"
    logger.remove()
    logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
    # Change the format saved in the log file
    logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)

    logger.info(args)
    if rank == 0:
        env = collect_env()
        import transformers
        import xtuner

        env["Transformers"] = transformers.__version__
        env["XTuner"] = f"{xtuner.__version__}+{get_git_hash(digits=6)}"
        runtime_env = OrderedDict()
        runtime_env.update(env)
        runtime_env["Seed"] = args.seed
        runtime_env["World Size"] = dist.get_world_size()

        runtime_env_info = "\n    " + "\n    ".join(f"{k}: {v}" for k, v in runtime_env.items())
        dash_line = "-" * 60
        logger.info("\n" + dash_line + "\nRuntime environment:" + runtime_env_info + "\n" + dash_line + "\n")
    # -------------------    Environment  End  ------------------------------ #

    ###########################################################################
    #                          2. FSDP                                        #
    ###########################################################################
    if args.dtype == "auto":
        args.dtype = "bf16" if DEVICE_MODULE.is_bf16_supported() else "fp16"

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        if DEVICE_MODULE.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            raise RuntimeError("The device does not support `bf16`, " "please set `dtype` to `fp16`.")
    else:
        raise RuntimeError("`dtype` only supports `fp16`, `bf16` or `auto`, " f"but found {args.dtype}.")

    with torch.device("meta"):
        # In order to save CPU memory and GPU memory,
        # initialize an empty complete model on all ranks first.
        # At the same time, a non-empty complete model will be loaded
        # on the CPU of rank0.
        # After the model is parallelized, the parameters of the complete
        # model on rank0 will be loaded.
        actor_model = AutoModelForCausalLM.from_pretrained(args.actor, attn_implementation="flash_attention_2", torch_dtype=dtype)

        for module in actor_model.modules():
            for p_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    param_fp32 = torch.nn.Parameter(param.to(dtype=torch.float32))
                    setattr(module, p_name, param_fp32)

        ref_model = AutoModelForCausalLM.from_pretrained(args.reference, attn_implementation="flash_attention_2", torch_dtype=dtype)

        for param in ref_model.parameters():
            param.requires_grad = False

        if args.token_level_rm is not None:
            token_level_rm = AutoModelForCausalLM.from_pretrained(
                args.token_level_rm, attn_implementation="flash_attention_2", torch_dtype=dtype
            )
            # replace the language model head with a reward model linear head
            token_level_rm.lm_head = torch.nn.Linear(token_level_rm.config.hidden_size, 1, bias=False)

            for module in token_level_rm.modules():
                for p_name, param in module.named_parameters(recurse=False):
                    if param.requires_grad:
                        # Ensure all numerical values in the optimizer are fp32.
                        # Don't worry about speed, FSDP will use `dtype`
                        # during forward.
                        param_fp32 = torch.nn.Parameter(param.to(dtype=torch.float32))
                        setattr(module, p_name, param_fp32)
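        # Why meta-device init plus fp32 master weights: FSDP runs forward and
        # backward in `dtype` (bf16/fp16), while the optimizer updates fp32
        # copies for numerical stability. A minimal standalone sketch of the
        # same pattern (commented out so it never runs here; `tiny` is a
        # hypothetical stand-in model, not part of this repo):
        #
        #   with torch.device("meta"):
        #       tiny = torch.nn.Linear(8, 8)  # allocates no real storage
        #   for name, param in list(tiny.named_parameters()):
        #       # upcast the (still-meta) parameter to an fp32 master copy
        #       setattr(tiny, name, torch.nn.Parameter(param.to(torch.float32)))
        #   # real weights are materialized later, when the model is parallelized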

    with profile_time_and_memory("[Parallelize Actor]"):
        actor_model = AutoPatch.from_causal_lm(
            actor_model,
            fsdp_config=FSDPConfig(
                tp_size=args.tp_size,
                sp_size=args.sp_size,
                param_dtype=dtype,
                reduce_dtype=dtype,
                cpu_offload=args.cpu_offload,
                reshard_after_forward=False,
                mesh_prefix="actor",
            ),
        )
    dist.barrier()

    with profile_time_and_memory("[Parallelize Reference]"):
        ref_model = AutoPatch.from_causal_lm(
            ref_model,
            fsdp_config=FSDPConfig(
                tp_size=args.tp_size,
                sp_size=args.sp_size,
                param_dtype=dtype,
                reduce_dtype=dtype,
                cpu_offload=args.cpu_offload,
                reshard_after_forward=True,
                mesh_prefix="ref",
            ),
        )
    dist.barrier()

    if args.token_level_rm is not None:
        with profile_time_and_memory("[Parallelize Reward]"):
            token_level_rm = AutoPatch.from_causal_lm(
                token_level_rm,
                fsdp_config=FSDPConfig(
                    tp_size=args.tp_size,
                    sp_size=args.sp_size,
                    param_dtype=dtype,
                    reduce_dtype=dtype,
                    cpu_offload=args.cpu_offload,
                    reshard_after_forward=True,
                    mesh_prefix="reward",
                ),
            )
            token_level_rm.train()
    dist.barrier()
    # --------------------------    FSDP  End  ------------------------------ #

    ###########################################################################
    #                     3. Dataset & Dataloader                             #
    ###########################################################################
    actor_sp_mesh = actor_model.sequence_parallel_mesh
    actor_dp_mesh = actor_model.data_parallel_mesh
    actor_data_mesh = actor_model.data_mesh
    actor_dp_size = actor_dp_mesh.size()

    actor_sp_size = actor_sp_mesh.size()

    prompt_global_batch = args.gen_global_batch // args.prompt_repeat_k

    tokenizer = AutoTokenizer.from_pretrained(args.actor, trust_remote_code=True, padding_side="right")

    if args.chat_template is not None:
        if rank == 0:
            logger.info(f"[CHAT_TEMPLATE] {args.chat_template}")
        tokenizer.chat_template = args.chat_template

    stop_token_ids = []
    word_ids = tokenizer.encode(args.stop_word, add_special_tokens=False)
    if len(word_ids) > 1:
        raise NotImplementedError("The stop word must be a single token.")
    stop_token_ids.append(word_ids[0])
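    # The generator matches the stop condition by token id, hence the
    # single-token requirement above. A quick sanity check (commented out;
    # "<|im_end|>" is only an illustrative stop word, the real one comes from
    # the config's `stop_word`):
    #
    #   ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
    #   assert len(ids) == 1, f"stop word split into {len(ids)} tokens: {ids}"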

    with profile_time_and_memory("[Dataset & Dataloader]"):

        prompt_dataset = OrealPromptDataset(
            args.datasets,
            tokenizer,
            difficulty_balance_cfg=args.data_difficulty_balance_cfg,
        )
        if rank == 0:
            logger.info(f"[Dataset] {len(prompt_dataset)} prompts.")

        assert is_flash_attn_2_available()
        prompt_collator = PromptCollator(pack_batch=True)
        prompt_sampler = ParallelSampler(prompt_dataset, actor_dp_mesh, prompt_global_batch, shuffle=True)

        prompt_dataloader = DataLoader(
            prompt_dataset,
            batch_size=prompt_global_batch // actor_dp_mesh.size(),
            num_workers=args.num_workers,
            # If you replace this with a custom sampler, make sure it rounds up
            # or drops the last batch consistently with `global_batch_size`.
            sampler=prompt_sampler,
            collate_fn=prompt_collator,
            persistent_workers=args.num_workers > 0,
        )

        if rank == 0:
            logger.info(f"[Dataloader] {len(prompt_dataloader)} batches.")
            _first_batch = [prompt_dataset[i] for i in range(prompt_global_batch)]
            logger.debug(f"[Dataloader] Training Batch:\n{_first_batch}")

    dist.barrier()
    # -------------------    Dataset & Dataloader  End  --------------------- #

    # ---------------------    Router  Start  ------------------------------- #
    judger_router = ParallelRouter(
        judgers_config=args.judgers_config,
        data_judger_mapping=args.data_judger_mapping,
        logger=logger,
    )
    # ---------------------    Router  End  --------------------------------- #

    ###########################################################################
    #                      4. Optimizer & Scheduler                           #
    ###########################################################################
    actor_params = [p for p in actor_model.parameters() if p.requires_grad]
    actor_optimizer = AdamW(actor_params, lr=args.actor_lr, weight_decay=args.wd)

    if args.token_level_rm is not None:
        token_rm_params = [p for p in token_level_rm.parameters() if p.requires_grad]
        token_rm_optimizer = AdamW(token_rm_params, lr=args.token_level_rm_lr, weight_decay=args.wd)

    total_steps = args.total_steps
    if total_steps > len(prompt_dataloader):
        logger.warning(f"Total steps {total_steps} is greater than the number of prompts {len(prompt_dataloader)}, set to dataloader length.")
        total_steps = len(prompt_dataloader)

    warmup_steps = args.warmup_steps
    rm_warmup_steps = args.get("rm_warmup_steps", warmup_steps)
    lr_min = args.get("actor_min_lr", args.actor_lr)
    token_level_rm_lr_min = args.get("token_level_rm_lr_min", args.token_level_rm_lr)

    if args.checkpoint_interval == -1:
        checkpoint_interval = total_steps
    elif args.checkpoint_interval < 1:
        checkpoint_interval = int(total_steps * args.checkpoint_interval)
    else:
        checkpoint_interval = int(args.checkpoint_interval)
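    # Checkpoint interval semantics, with hypothetical numbers for illustration:
    #   checkpoint_interval == -1   -> total_steps (checkpoint only at the end)
    #   checkpoint_interval == 0.25 -> total_steps * 0.25 (e.g. 200 steps -> every 50)
    #   checkpoint_interval == 10   -> every 10 steps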

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(actor_optimizer, warmup_fn)
    cosine_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_steps - warmup_steps, eta_min=lr_min)

    if args.token_level_rm is not None:

        def rm_warmup_fn(x):
            return x / rm_warmup_steps if x < rm_warmup_steps else 1

        token_rm_warmup_scheduler = LambdaLR(token_rm_optimizer, rm_warmup_fn)
        token_rm_cosine_scheduler = CosineAnnealingLR(token_rm_optimizer, T_max=total_steps - rm_warmup_steps, eta_min=token_level_rm_lr_min)
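    # The resulting actor LR curve: linear warmup to `actor_lr` over
    # `warmup_steps`, then cosine decay towards `actor_min_lr`. A closed-form
    # sketch of what the two schedulers implement together (commented out;
    # `lr_at` is a hypothetical helper, not part of this repo):
    #
    #   import math
    #   def lr_at(step, base_lr, min_lr, warmup_steps, total_steps):
    #       if step < warmup_steps:
    #           return base_lr * step / warmup_steps
    #       t = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    #       return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))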

    # ----------------    Optimizer & Scheduler End   ----------------------- #

    ###########################################################################
    #                          5. Training                                    #
    ###########################################################################

    if args.filter_trajectory:
        trajectory_dataset = TrajectoryDatasetWithFilter(repeat_k=args.prompt_repeat_k)
    else:
        trajectory_dataset = TrajectoryDataset()

    prompt_iterator = iter(prompt_dataloader)

    start_step = 0
    start_train_t = time.time()
    DEVICE_MODULE.empty_cache()
    DEVICE_MODULE.reset_peak_memory_stats()
    max_memory = DEVICE_MODULE.max_memory_allocated()
    logger.info("[Train] Begin Train Loop. The current GPU memory is " f"{(max_memory / 1024**3):.1f}GB")

    for step in range(start_step, total_steps):

        if step <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
            if args.token_level_rm is not None:
                token_rm_warmup_scheduler.step()
                token_rm_cur_lr = token_rm_warmup_scheduler.get_last_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_last_lr()[0]
            if args.token_level_rm is not None:
                token_rm_cosine_scheduler.step()
                token_rm_cur_lr = token_rm_cosine_scheduler.get_last_lr()[0]

        DEVICE_MODULE.reset_peak_memory_stats()

        step_kl_penalty_loss = 0
        step_rl_loss = 0
        step_token_level_rm_loss = 0
        step_start_t = time.time()
        step_positive_loss = 0
        step_negative_loss = 0

        if step < args.actor_freeze_steps:
            # Only update the parameters of the token-level reward model
            update_actor = False
        else:
            update_actor = True

        DEVICE_MODULE.reset_peak_memory_stats()

        data = next(prompt_iterator)
        prompt_input_ids = unpack_sequence(data["input_ids"].to(DEVICE), data["num_tokens"])
        infer_num_tokens = data["num_tokens"].to(DEVICE)
        # repeat each prompt `prompt_repeat_k` times
        prompt_input_ids = [p for p in prompt_input_ids for _ in range(args.prompt_repeat_k)]  # AAAABBBBCCCC
        infer_num_tokens = torch.Tensor([n for n in infer_num_tokens for _ in range(args.prompt_repeat_k)])
        message_data = [m for m in data["message_data"] for _ in range(args.prompt_repeat_k)]
        metadata = [m for m in data["metadata"] for _ in range(args.prompt_repeat_k)]
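        # Repeat layout: each prompt appears `prompt_repeat_k` times back to
        # back, so samples of the same prompt stay contiguous and can later be
        # reshaped to (-1, prompt_repeat_k) for reward shaping. E.g. with
        # prompts [A, B] and prompt_repeat_k = 2 (illustrative values):
        #   [p for p in [A, B] for _ in range(2)]  ->  [A, A, B, B]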

        # Stage 1,  Actor Model Generation
        step_avg_new_tokens = 0
        step_gen_start_t = time.time()

        actor_model.eval()
        # During the generation stage, sequence parallelism is not used even
        # when the sp size is greater than 1; instead, each sp rank processes
        # different prompts in parallel.
        responses = actor_model.generate(
            prompt_input_ids,
            stop_token_ids,
            max_length=args.gen_max_length,
            max_batch_size=len(prompt_input_ids),
            max_prefill_batch=args.max_prefill_batch,
            max_new_tokens=args.gen_max_new,
            do_sample=args.gen_do_sample,
            top_k=args.gen_top_k,
            top_p=args.gen_top_p,
            temperature=args.temperature,
            cuda_graph=args.cuda_graph,
        )

        # decode responses
        response_texts = [tokenizer.decode(res, skip_special_tokens=False) for res in responses]

        actor_model.train()
        dist.barrier()

        step_avg_new_tokens = sum([len(res) for res in responses]) / len(responses)
        step_gen_time = time.time() - step_gen_start_t

        prompt_input_ids = [p[0].tolist() for p in prompt_input_ids]

        # Stage 2,  Infer
        step_infer_start_t = time.time()
        step_infer_consumed_tokens = 0

        # submit to judger
        if actor_data_mesh.get_local_rank() == 0:
            submit_batch = []
            for i in range(len(message_data)):
                submit_batch.append(
                    {
                        "prompt_messages": message_data[i],
                        "completion_messages": [{"role": "assistant", "content": response_texts[i]}],
                        "metadata": metadata[i],
                    }
                )
            token, indexes_for_local = judger_router.submit(submit_batch)

        # `infer_dataset` differs on each dp rank, so there is no need to
        # use the parallel sampler.
        infer_dataset = InferDataset(
            prompt_input_ids,
            responses,
            message_data,
            metadata,
        )
        infer_dataloader = DataLoader(
            infer_dataset,
            batch_size=args.rl_mirco_batch,
            num_workers=0,
            collate_fn=SftCollator(pack_batch=True),
            shuffle=False,
            persistent_workers=False,
        )

        policies = []
        for infer_packed_seq in infer_dataloader:
            # labels are already shifted in InferDataset
            infer_labels = infer_packed_seq["labels"].to(DEVICE)
            infer_input_ids = infer_packed_seq["input_ids"].to(DEVICE)
            infer_num_tokens = infer_packed_seq["num_tokens"].to(DEVICE)
            infer_batch_size = infer_num_tokens.numel()

            step_infer_consumed_tokens += infer_num_tokens.sum() / actor_data_mesh.size()

            unpacked_input_ids = unpack_sequence(infer_input_ids, infer_num_tokens, dim=1)
            unpacked_labels = unpack_sequence(infer_labels, infer_num_tokens, dim=1)

            for i in range(infer_batch_size):
                assert unpacked_input_ids[i].numel() == infer_num_tokens[i]
                assert unpacked_labels[i].numel() == infer_num_tokens[i]

                _policy = {
                    "input_ids": unpacked_input_ids[i].flatten().tolist(),
                    "labels": unpacked_labels[i].flatten().tolist(),
                    "num_tokens": infer_num_tokens[i].item(),
                }
                _policy["sequence_text"] = tokenizer.decode(_policy["input_ids"], skip_special_tokens=False)
                policies.append(_policy)

        step_infer_time = time.time() - step_infer_start_t

        # ------------------------------------------------------------- #
        # --------------------Stage 3, Get Judger Reward--------------- #
        # ------------------------------------------------------------- #
        # query results from judger
        if actor_data_mesh.get_local_rank() == 0:
            while True:
                try:
                    judger_results = judger_router.query(token, timeout=3)
                    logger.info(f"Query judger results: {judger_results}")
                    break
                except TimeoutError as e:
                    logger.info(f"Judger query timeout: {e}. Will retry")
            judger_rewards = [list(r.values())[0] for r in judger_results]
            judger_rewards = [r if r is not None else -1.0 for r in judger_rewards]
            judger_rewards = torch.tensor(judger_rewards, dtype=torch.float32).to(DEVICE)
        else:
            judger_rewards = torch.tensor([0] * len(policies), dtype=torch.float32).to(DEVICE)

        dist.barrier()
        # broadcast judger rewards within the same data mesh
        dist.all_reduce(judger_rewards, op=dist.ReduceOp.SUM, group=actor_data_mesh.get_group())
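        # Only local rank 0 of each data mesh filled in real rewards; all
        # other ranks contributed zeros, so the SUM all-reduce above behaves
        # like a broadcast. The pattern in isolation (a sketch; `group` and
        # `real_values` are stand-ins):
        #
        #   x = real_values if local_rank == 0 else torch.zeros_like(real_values)
        #   dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
        #   # every rank in `group` now holds real_values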

        # reward shaping, use GRPO or RLOO to normalize rewards
        _rewards = judger_rewards.reshape(-1, args.prompt_repeat_k).T
        if args.reward_shaping_type == "rloo":
            baseline = (_rewards.sum(0) - _rewards) / (args.prompt_repeat_k - 1)
            judger_advantages = _rewards - baseline
        elif args.reward_shaping_type == "grpo":
            judger_advantages = (_rewards - _rewards.mean(0)) / (_rewards.std(0) + 1e-8)
        else:
            raise NotImplementedError(f"Reward shaping type {args.reward_shaping_type} is not implemented.")
        judger_advantages = judger_advantages.T.flatten()
        # update policies
        assert len(judger_rewards) == len(policies)
        for i in range(len(policies)):
            policies[i]["judger_reward"] = judger_rewards[i].item()
            policies[i]["judger_advantage"] = judger_advantages[i].item()

        # ------------------------------------------------------------- #
        # --------------------------Stage 4, RL------------------------ #
        # ------------------------------------------------------------- #
        step_rl_start_t = time.time()

        _global_policies = [None] * actor_dp_size
        dist.all_gather_object(_global_policies, policies, actor_dp_mesh.get_group())

        global_policies = []
        for _rank_policies in _global_policies:
            global_policies.extend(_rank_policies)

        trajectory_dataset.update(global_policies)
        if rank == 0:
            # dump trajectory
            _buffer_dir = os.path.join(args.work_dir, "trajectories")
            mkdir_or_exist(_buffer_dir)
            _buffer_file = os.path.join(_buffer_dir, f"step.{step}.jsonl")
            trajectory_dataset.dump_jsonl(_buffer_file, tokenizer, args.debug)
            _buffer_log_file = os.path.join(_buffer_dir, f"step.{step}.log")
            trajectory_dataset.dump_log(_buffer_log_file, tokenizer, args.debug)

        rl_global_batch = args.rl_global_batch
        if args.filter_trajectory:
            _world_size = actor_dp_mesh.size()
            _data_size = len(trajectory_dataset)
            # round down so that rl_global_batch is divisible by the dp world size
            rl_global_batch = _data_size // _world_size * _world_size

        rl_loader = DataLoader(
            trajectory_dataset,
            batch_size=args.rl_mirco_batch,
            num_workers=0,
            collate_fn=TrajectoryCollator(pack_batch=True),
            shuffle=False,
            sampler=RLParallelSampler(trajectory_dataset, actor_dp_mesh, rl_global_batch, shuffle=False),
            persistent_workers=False,
        )

        # Count the total number of action tokens used for RL training across
        # all ranks. This is required for the per-token loss; otherwise the
        # number of tokens contributing to each backward pass would be
        # unbalanced across ranks.
        global_action_tokens = trajectory_dataset.num_action_tokens
        global_positive_tokens = sum(
            [(torch.tensor(t["labels"]) >= 0).sum().item() for t in trajectory_dataset._trajectories if t["judger_reward"] > 0]
        )
        global_negative_tokens = global_action_tokens - global_positive_tokens
        global_num_seqs = len(trajectory_dataset._trajectories)
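        # Why the losses below are scaled by `actor_dp_size / global_*_tokens`:
        # gradient reduction averages over dp ranks, so multiplying by the dp
        # size first and dividing by the *global* token count turns the summed
        # objective into a true per-token mean over all ranks, no matter how
        # unevenly the action tokens are distributed. Illustrative numbers:
        # 2 dp ranks holding 300 and 100 action tokens -> each token is
        # weighted 2/400 before the 1/2 averaging, i.e. exactly 1/400 overall.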

        step_avg_judger_reward = sum([t["judger_reward"] for t in global_policies]) / len(global_policies)
        step_sum_gen_entropy = 0
        step_sum_ref_kl = 0
        step_action_tokens = 0
        step_rl_consumed_tokens = 0

        step_sum_adv = 0

        for packed_policy in rl_loader:

            rl_input_ids = packed_policy["input_ids"].to(DEVICE)
            rl_num_tokens = packed_policy["num_tokens"].to(DEVICE)
            assert rl_input_ids.numel() == rl_num_tokens.sum()
            rl_batch_size = rl_num_tokens.numel()
            # labels are already shifted in InferDataset
            rl_labels = packed_policy["labels"].to(DEVICE)

            judger_rewards = torch.Tensor(packed_policy["judger_rewards"]).to(DEVICE)  # shape: (rl_mirco_batch, )
            judger_advantages = torch.Tensor(packed_policy["judger_advantages"]).to(DEVICE)  # shape: (rl_mirco_batch, )

            actor_input_ids = rl_input_ids.clone()
            actor_labels = rl_labels.clone()
            actor_num_tokens = rl_num_tokens.clone().tolist()

            actor_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + actor_num_tokens), dim=0).to(DEVICE).int()
            actor_position_ids = [torch.arange(num) for num in actor_num_tokens]
            actor_position_ids = torch.cat(actor_position_ids, dim=0).to(DEVICE).unsqueeze_(0)

            with nullcontext() if update_actor else torch.no_grad():
                packed_actor_logits = actor_model(
                    input_ids=actor_input_ids,
                    position_ids=actor_position_ids,
                    use_cache=False,
                    cu_seq_lens_q=actor_cu_seq_lens,
                    cu_seq_lens_k=actor_cu_seq_lens,
                    max_length_q=max(actor_num_tokens),
                    max_length_k=max(actor_num_tokens),
                    sequence_parallel_mesh=actor_sp_mesh,
                ).logits

            # -------sft loss--------
            # Compute the SFT loss on each sp (tp) rank and then gather the
            # per-token losses to the dp rank; gathering the full logits
            # instead could cause OOM.
            if actor_model.fsdp_config.torch_compile:
                _actor_labels = pad_to_max_length(actor_labels, -100, actor_model.fsdp_config.max_length, 1)
            else:
                if actor_sp_mesh and actor_sp_mesh.size() > 1:
                    multiple_of = actor_sp_mesh.size() * actor_model.tp_mesh.size()
                else:
                    multiple_of = actor_model.tp_mesh.size()
                _actor_labels = pad_to_multiple_of(actor_labels, -100, multiple_of, 1)

            if actor_sp_mesh and actor_sp_mesh.size() > 1:
                _actor_labels = split_for_sequence_parallel(_actor_labels, dim=1, sp_mesh=actor_sp_mesh)

            if actor_model.tp_mesh.size() > 1:
                _actor_labels = split_for_sequence_parallel(_actor_labels, dim=1, sp_mesh=actor_model.tp_mesh)
            packed_sft_loss = F.cross_entropy(packed_actor_logits.squeeze(), _actor_labels.squeeze(), reduction="none").unsqueeze(
                0
            )  # shape: 1, seqlen

            if actor_model.tp_mesh.size() > 1:
                _packed_sft_loss = dist.nn.all_gather(packed_sft_loss, group=actor_model.tp_mesh.get_group())
                packed_sft_loss = torch.cat(_packed_sft_loss, dim=1)

            if actor_sp_mesh and actor_sp_mesh.size() > 1:
                _packed_sft_loss = dist.nn.all_gather(packed_sft_loss, group=actor_sp_mesh.get_group())
                packed_sft_loss = torch.cat(_packed_sft_loss, dim=1)

            packed_sft_loss = packed_sft_loss[:, : actor_labels.size(1)]
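            # Roundtrip recap: the labels were padded to a multiple of the
            # sp*tp size, split so each rank scores only its local logits
            # shard, the per-token losses were all-gathered, and the padding
            # is sliced off here. Peak memory is thus bounded by a logits
            # shard instead of the full (seqlen, vocab_size) tensor.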

            # The labels of prefill tokens and the last token are -100.
            # HACK (for sp): clip -100 up to 0 here; these positions are
            # masked out later.
            packed_logprobs = actor_model.gather_logprobs(packed_actor_logits, actor_labels.clip(0), actor_sp_mesh)

            logprobs = unpack_sequence(packed_logprobs, actor_num_tokens, dim=1)
            sft_loss = unpack_sequence(packed_sft_loss, actor_num_tokens, dim=1)

            ref_input_ids = rl_input_ids.clone()
            ref_labels = rl_labels.clone()
            ref_num_tokens = rl_num_tokens.clone().tolist()

            ref_cu_seq_lens = torch.cumsum(torch.IntTensor([0] + ref_num_tokens), dim=0).to(DEVICE).int()
            ref_position_ids = [torch.arange(num) for num in ref_num_tokens]
            ref_position_ids = torch.cat(ref_position_ids, dim=0).to(DEVICE).unsqueeze_(0)

            with torch.no_grad():
                packed_ref_logits = ref_model(
                    input_ids=ref_input_ids,
                    position_ids=ref_position_ids,
                    use_cache=False,
                    cu_seq_lens_q=ref_cu_seq_lens,
                    cu_seq_lens_k=ref_cu_seq_lens,
                    max_length_q=max(ref_num_tokens),
                    max_length_k=max(ref_num_tokens),
                    sequence_parallel_mesh=actor_sp_mesh,
                ).logits

            if args.token_level_rm is not None:
                packed_rm_logits = token_level_rm(
                    input_ids=ref_input_ids,
                    position_ids=ref_position_ids,
                    use_cache=False,
                    cu_seq_lens_q=ref_cu_seq_lens,
                    cu_seq_lens_k=ref_cu_seq_lens,
                    max_length_q=max(ref_num_tokens),
                    max_length_k=max(ref_num_tokens),
                    sequence_parallel_mesh=actor_sp_mesh,
                ).logits
                # the reward head outputs a single logit per token; drop the trailing feature dim
                packed_rm_logits = packed_rm_logits[:, :, 0]  # TODO: replace with auto path rm
                if token_level_rm.tp_mesh.size() > 1:
                    _packed_rm_logits = dist.nn.all_gather(packed_rm_logits, group=token_level_rm.tp_mesh.get_group())
                    packed_rm_logits = torch.cat(_packed_rm_logits, dim=1)
                if actor_sp_mesh and actor_sp_mesh.size() > 1:
                    _packed_rm_logits = dist.nn.all_gather(packed_rm_logits, group=actor_sp_mesh.get_group())
                    packed_rm_logits = torch.cat(_packed_rm_logits, dim=1)
                packed_rm_logits = packed_rm_logits[:, : actor_labels.size(1)]
                rm_logits = unpack_sequence(packed_rm_logits, ref_num_tokens, dim=1)

            # The labels of prefill tokens and the last token are -100.
            # HACK (for sp): clip -100 up to 0 here; these positions are
            # masked out later.
            packed_ref_logprobs = ref_model.gather_logprobs(packed_ref_logits, ref_labels.clip(0), actor_sp_mesh)
            ref_logprobs = unpack_sequence(packed_ref_logprobs, ref_num_tokens, dim=1)
            unpacked_labels = unpack_sequence(rl_labels, rl_num_tokens, dim=1)

            _positive_losses = []
            _negative_losses = []
            _kl_penalty_losses = []
            _token_level_rm_losses = []
            _losses = []
            for i in range(rl_batch_size):
                _judger_reward = judger_rewards[i]
                assert unpacked_labels[i].numel() == rl_num_tokens[i]
                # from the last prefill token to the second-to-last token (the eos token is excluded)
                _num_action_tokens = (unpacked_labels[i] >= 0).sum()

                _logprobs = logprobs[i][0, -_num_action_tokens - 1 : -1]
                _ref_logprobs = ref_logprobs[i][0, -_num_action_tokens - 1 : -1]

                _old_logprobs = _logprobs.detach()
                _judger_advantages = judger_advantages[i]

                if args.token_level_rm is not None:
                    # compute cumulative mean of rm scores
                    _rm_scores = rm_logits[i][0, -_num_action_tokens - 1 : -1]
                    _cum_mean_rm_scores = _rm_scores.cumsum(0).squeeze() / torch.arange(1, _num_action_tokens + 1).to(DEVICE)
                    _seq_mean_rm_scores = _rm_scores.mean()

                    # ----------token level rm loss (cross entropy)------------
                    _rm_label = torch.tensor([int(max(_judger_reward, 0))]).to(DEVICE)
                    _seq_mean_rm_scores = _seq_mean_rm_scores.reshape(_rm_label.shape)
                    _token_level_rm_loss = F.binary_cross_entropy_with_logits(_seq_mean_rm_scores.float(), _rm_label.float(), reduction="none")
                    _token_level_rm_loss = _token_level_rm_loss.sum() * actor_dp_size / global_num_seqs
                    _token_level_rm_losses.append(_token_level_rm_loss)
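                    # The RM target is binary: 1 if the judger marked the
                    # sample correct, else 0, trained with BCE on the mean RM
                    # logit over the action tokens. Illustrative numbers:
                    #   mean logit = 1.5, judger_reward > 0 -> target 1
                    #   loss = -log(sigmoid(1.5)) ~= 0.20
                    # The actor_dp_size / global_num_seqs factor makes this a
                    # per-sequence mean across all dp ranks, mirroring the
                    # per-token scaling used for the policy losses.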

                    # use probability to reweight policy loss
                    _correct_prob = torch.sigmoid(_cum_mean_rm_scores).detach()
                    _incorrect_prob = 1 - _correct_prob

                    if args.get("threshold_rescale", False):
                        correct_threshold = args.get("correct_threshold", 0.5)
                        incorrect_threshold = args.get("incorrect_threshold", 0.5)
                        _pos_weight = threshold_rescale(_correct_prob, correct_threshold)
                        _neg_weight = threshold_rescale(_incorrect_prob, incorrect_threshold)
                    elif args.get("topk_rescale", False):
                        correct_topk_ratio = args.get("correct_topk_ratio", 0.5)
                        incorrect_topk_ratio = args.get("incorrect_topk_ratio", 0.5)
                        _pos_weight = topk_rescale(_correct_prob, correct_topk_ratio)
                        _neg_weight = topk_rescale(_incorrect_prob, incorrect_topk_ratio)
                    else:
                        raise NotImplementedError("Only support threshold_rescale and topk_rescale.")
                else:
                    _pos_weight, _neg_weight = 1.0, 1.0
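                # The token-level RM path above turns cumulative-mean scores
                # into a per-token "this prefix is still on a correct path"
                # probability via sigmoid; threshold_rescale / topk_rescale
                # (defined earlier in this file) then map those probabilities
                # to per-token loss weights. The probability step on
                # illustrative scores:
                #
                #   scores = torch.tensor([2.0, 0.0, -2.0])
                #   cum_mean = scores.cumsum(0) / torch.arange(1, 4)  # [2.0, 1.0, 0.0]
                #   torch.sigmoid(cum_mean)  # ~[0.881, 0.731, 0.500]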

                # ----------positive loss (behavior cloning)------------
                _positive_loss = sft_loss[i][0, -_num_action_tokens - 1 : -1]
                _positive_loss = (_positive_loss * _pos_weight).sum()
                if args.get("pos_mult_adv", False):
                    _positive_loss = _positive_loss * _judger_advantages
                if _judger_reward > 0:
                    _positive_loss = _positive_loss * actor_dp_size / global_positive_tokens * args.positive_loss_factor
                else:
                    # negative sample does not need sft loss
                    _positive_loss = torch.zeros_like(_positive_loss)
                _positive_losses.append(_positive_loss)

                # ----------negative loss (policy gradient)------------
                if _judger_reward > 0:
                    # positive sample: no policy loss needed
                    _negative_loss = torch.zeros_like(_positive_loss)
                    _kl_penalty_loss = torch.zeros_like(_positive_loss)
                    _negative_losses.append(_negative_loss)
                else:
                    _advantages = _judger_advantages * _neg_weight
                    _negative_loss = torch.exp(_logprobs - _old_logprobs.detach()) * _advantages
                    _negative_loss = -torch.sum(_negative_loss) * actor_dp_size / global_negative_tokens * args.negative_loss_factor
                    _negative_losses.append(_negative_loss)
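                    # On-policy note: `_old_logprobs` is `_logprobs.detach()`
                    # from the same forward pass, so the importance ratio
                    # exp(_logprobs - _old_logprobs) equals 1 in value, while
                    # its gradient w.r.t. _logprobs is still 1; the line above
                    # is therefore the vanilla policy-gradient (REINFORCE)
                    # estimator written in PPO-ratio form.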

                # ----------compute kl penalty------------
                assert _logprobs.ndim == 1
                kl_type = args.get("kl_type", "unbias")  # kl, unbias, mse
                if kl_type == "kl":
                    kl = _ref_logprobs - _logprobs
                    _kl_penalty_loss = (args.kl_coef * kl).sum() * actor_dp_size / global_action_tokens
                elif kl_type == "unbias":
                    kl = _ref_logprobs - _logprobs
                    nonneg_nobias_kl = torch.exp(kl) - kl - 1
                    _kl_penalty_loss = (args.kl_coef * nonneg_nobias_kl).sum() * actor_dp_size / global_action_tokens
                elif kl_type == "mse":
                    _kl_penalty_loss = (
                        (args.kl_coef * (_ref_logprobs - _logprobs).square() / 2).sum() * actor_dp_size / global_action_tokens
                    )
                _kl_penalty_losses.append(_kl_penalty_loss)
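                # The three options, with r = ref_logprob - logprob per token
                # (a sketch of the math, not code from this repo):
                #   "kl":     r              (first-order surrogate; can go negative)
                #   "unbias": exp(r) - r - 1 (the k3-style estimator: non-negative,
                #                             unbiased for KL(pi || pi_ref))
                #   "mse":    r^2 / 2        (quadratic penalty on log-prob drift)
                # e.g. r = 0.1 -> kl: 0.1000, unbias: 0.0052, mse: 0.0050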

                # ----------compute total loss------------
                _loss = _positive_loss + _negative_loss + _kl_penalty_loss
                _losses.append(_loss)

                step_sum_gen_entropy += -_old_logprobs.sum().item()
                step_sum_ref_kl += (_old_logprobs - _ref_logprobs).sum().item()
                step_sum_adv += _judger_advantages.sum().item()
                step_action_tokens += _num_action_tokens.item()

            loss = sum(_losses)
            if update_actor:
                loss.backward()

            # for logging
            step_positive_loss += sum(_positive_losses).item()
            step_negative_loss += sum(_negative_losses).item()
            step_kl_penalty_loss += sum(_kl_penalty_losses).item()
            step_rl_loss += loss.item()
            step_rl_consumed_tokens += rl_num_tokens.sum() / actor_data_mesh.size()

            if args.token_level_rm is not None:
                token_level_rm_loss = sum(_token_level_rm_losses)
                token_level_rm_loss.backward()
                step_token_level_rm_loss += token_level_rm_loss.item()

        step_rl_time = time.time() - step_rl_start_t
        step_avg_ref_kl = step_sum_ref_kl / step_action_tokens
        step_avg_gen_entropy = step_sum_gen_entropy / step_action_tokens
        step_avg_adv = step_sum_adv / step_action_tokens

        actor_data_group = actor_data_mesh.get_group()
        step_avg_ref_kl = reduce_mean(step_avg_ref_kl, actor_data_group)
        step_avg_gen_entropy = reduce_mean(step_avg_gen_entropy, actor_data_group)
        step_avg_adv = reduce_mean(step_avg_adv, actor_data_group)
        step_avg_new_tokens = reduce_mean(step_avg_new_tokens, actor_data_group)

        if update_actor:
            actor_grad_norm = actor_model.clip_grad_norm(args.max_grad_norm)
            actor_grad_norm = actor_grad_norm.item()
            actor_optimizer.step()
            actor_optimizer.zero_grad()
        else:
            actor_grad_norm = 0

        if args.token_level_rm is not None:
            token_rm_grad_norm = token_level_rm.clip_grad_norm(args.max_grad_norm)
            token_rm_grad_norm = token_rm_grad_norm.item()
            token_rm_optimizer.step()
            token_rm_optimizer.zero_grad()

        step_time = time.time() - step_start_t
        eta = step_time * (total_steps - step)
        eta = timedelta(seconds=int(eta))

        infer_tgs = int(step_infer_consumed_tokens / step_infer_time)
        rl_tgs = int(step_rl_consumed_tokens / step_rl_time)

        actor_lr = cur_lr if update_actor else 0.0
        max_memory = DEVICE_MODULE.max_memory_allocated()
        log_dict = {
            "step": step + 1,
            "actor_lr": actor_lr,
            "actor_grad_norm": actor_grad_norm,
            "token_level_rm_lr": token_rm_cur_lr if args.token_level_rm is not None else 0.0,
            "token_rm_grad_norm": token_rm_grad_norm if args.token_level_rm is not None else 0.0,
            "avg_judger_reward": step_avg_judger_reward,
            "avg_adv": step_avg_adv,
            "avg_gen_entropy": step_avg_gen_entropy,
            "avg_ref_kl": step_avg_ref_kl,
            "positive_loss": step_positive_loss,
            "negative_loss": step_negative_loss,
            "kl_penalty_loss": step_kl_penalty_loss,
            "rl_loss": step_rl_loss,
            "token_level_rm_loss": step_token_level_rm_loss if args.token_level_rm is not None else 0.0,
            "max_memory": max_memory / 1024**3,
            "avg_new_tokens": step_avg_new_tokens,
            "num_rl_tokens": step_rl_consumed_tokens,
            "infer_tgs": infer_tgs,
            "rl_tgs": rl_tgs,
            "gen_time": step_gen_time,
            "infer_time": step_infer_time,
            "rl_time": step_rl_time,
            "total_time": step_time,
            "eta": eta.seconds,
        }
        for key, value in log_dict.items():
            if isinstance(value, torch.Tensor):
                log_dict[key] = value.item()
        with open(os.path.join(args.work_dir, f"rank{rank}.log.jsonl"), "a") as f:
            f.write(json.dumps(log_dict, ensure_ascii=False) + "\n")

        if is_interval(step, total_steps, args.log_interval):
            logger.info(
                "[Train] Step "
                f"{step + 1}/{total_steps}  "
                f"actor_lr: {cur_lr:.3e}  "
                f"actor_grad_norm: {actor_grad_norm:.3f}  "
                f"token_level_rm_lr: {token_rm_cur_lr if args.token_level_rm is not None else 0.0:.3e}  "
                f"token_rm_grad_norm: {token_rm_grad_norm if args.token_level_rm is not None else 0.0:.3f}  "
                f"avg_judger_reward: {step_avg_judger_reward:.8f}  "
                f"avg_adv: {step_avg_adv:.8f}  "
                f"avg_gen_entropy: {step_avg_gen_entropy:.3f}  "
                f"avg_ref_kl: {step_avg_ref_kl:.8f}  "
                f"positive_loss: {step_positive_loss:.3f}  "
                f"negative_loss: {step_negative_loss:.3f}  "
                f"kl_penalty_loss: {step_kl_penalty_loss:.3f}  "
                f"rl_loss: {step_rl_loss:.3f}  "
                f"token_level_rm_loss: {step_token_level_rm_loss if args.token_level_rm is not None else 0.0:.3f}  "
                f"kl_coef: {args.kl_coef:.5f}  "
                f"max_memory: {(max_memory / 1024**3):.1f}GB  "
                f"avg_new_tokens: {int(step_avg_new_tokens)}  "
                f"num_rl_tokens: {int(step_rl_consumed_tokens)}  "
                f"infer_tgs: {int(infer_tgs)}  "
                f"rl_tgs: {int(rl_tgs)}  "
                f"gen_time: {step_gen_time:.2f}s  "
                f"infer_time: {step_infer_time:.2f}s  "
                f"rl_time: {step_rl_time:.2f}s  "
                f"total_time: {step_time:.2f}s  "
                f"eta: {eta}"
            )

        if is_interval(step, total_steps, checkpoint_interval):
            DEVICE_MODULE.empty_cache()

            num_digits = len(str(abs(total_steps)))
            work_dir = args.work_dir
            ckpt_dir = os.path.join(work_dir, f"ckpt-{step+1:0{num_digits}}")
            hf_dir = os.path.join(work_dir, f"hf-{step+1:0{num_digits}}")

            with profile_time_and_memory("[Checkpoint]"):
                actor_model.save_pretrained(hf_dir)
                tokenizer.save_pretrained(hf_dir)

                dist.barrier()

    train_cost_time = time.time() - start_train_t
    logger.success(f"[Train] Cost {timedelta(seconds=int(train_cost_time))}")
    # ------------------------    Training  End  ---------------------------- #


if __name__ == "__main__":
    fire.Fire(train_oreal)
================================================
SYMBOL INDEX (81 symbols across 8 files)
================================================

FILE: oreal/datasets/prompt.py
  function load_hf_datasets (line 13) | def load_hf_datasets(repo, split="train"):
  function load_jsonl_datasets (line 31) | def load_jsonl_datasets(file_path):
  function balance_difficulty_with_cfg (line 51) | def balance_difficulty_with_cfg(dataset, difficulty_balance_cfg):
  class OrealPromptDataset (line 65) | class OrealPromptDataset(Dataset):
    method __init__ (line 66) | def __init__(self, path, tokenizer, difficulty_balance_cfg=None):
    method __len__ (line 80) | def __len__(self):
    method __getitem__ (line 83) | def __getitem__(self, idx):
  class PromptCollator (line 94) | class PromptCollator:
    method __init__ (line 96) | def __init__(self, pad_token_id=0, ignore_id=-100, pack_batch=False):
    method __call__ (line 101) | def __call__(self, instances):

FILE: oreal/datasets/trajectory.py
  class InferDataset (line 13) | class InferDataset(torch.utils.data.Dataset):
    method __init__ (line 15) | def __init__(self, prompts_input_ids, responses_ids, message_data, met...
    method __len__ (line 29) | def __len__(self):
    method __getitem__ (line 32) | def __getitem__(self, item):
  class TrajectoryDataset (line 50) | class TrajectoryDataset(torch.utils.data.Dataset):
    method __init__ (line 52) | def __init__(self):
    method num_action_tokens (line 59) | def num_action_tokens(self):
    method num_total_tokens (line 63) | def num_total_tokens(self):
    method update (line 66) | def update(self, trajectories):
    method dump_jsonl (line 79) | def dump_jsonl(self, path, tokenizer, debug=False):
    method dump_log (line 101) | def dump_log(self, path, tokenizer, debug=False):
    method __len__ (line 111) | def __len__(self):
    method __getitem__ (line 114) | def __getitem__(self, item):
  class TrajectoryDatasetWithFilter (line 119) | class TrajectoryDatasetWithFilter(TrajectoryDataset):
    method __init__ (line 120) | def __init__(self, repeat_k=1, only_keep_1_pair=True):
    method update (line 125) | def update(self, trajectories):
  class TrajectoryCollator (line 157) | class TrajectoryCollator(SftCollator):
    method __call__ (line 159) | def __call__(self, instances):

FILE: oreal/judgers/base_judger.py
  class JudgeStatus (line 22) | class JudgeStatus(Generic[T]):
  class BaseJudger (line 28) | class BaseJudger(ABC):
    method __init__ (line 29) | def __init__(self):
    method on_data_received (line 33) | def on_data_received(
    method on_reward_required (line 42) | def on_reward_required(
  function register_judger (line 53) | def register_judger(name: str):

FILE: oreal/judgers/math_judger.py
  class MathJudger (line 14) | class MathJudger(BaseJudger):
    method __init__ (line 50) | def __init__(
    method on_data_received (line 71) | def on_data_received(
    method on_reward_required (line 131) | def on_reward_required(
    method _evaluate_answer_with_llm (line 140) | def _evaluate_answer_with_llm(
    method _verify_from_string (line 184) | def _verify_from_string(self, verification: str):
    method _extract_and_verify_with_logic (line 193) | def _extract_and_verify_with_logic(

FILE: oreal/judgers/router.py
  class InputData (line 38) | class InputData(TypedDict):
  class GenericTask (line 48) | class GenericTask(Generic[T]):
  class SubprocessConfig (line 56) | class SubprocessConfig:
  class ParallelRouter (line 61) | class ParallelRouter:
    method __init__ (line 62) | def __init__(
    method submit (line 167) | def submit(self, data_batch: List[InputData]):
    method query (line 232) | def query(
    method _safe_process_worker (line 264) | def _safe_process_worker(
    method _process_worker (line 287) | def _process_worker(
    method _build_judger (line 427) | def _build_judger(judger_name: str, judger_conf: dict):
    method _try_catch_subprocess_exceptions (line 442) | def _try_catch_subprocess_exceptions(self):
    method shutdown (line 463) | def shutdown(self, timeout: float = 2.0):

FILE: oreal/judgers/utils.py
  function extract_answer (line 15) | def extract_answer(pred_str: str, execute: bool = False) -> str:
  function _fix_fracs (line 81) | def _fix_fracs(string):
  function _fix_a_slash_b (line 113) | def _fix_a_slash_b(string):
  function _fix_sqrt (line 130) | def _fix_sqrt(string):
  function strip_string (line 135) | def strip_string(string):
  function last_boxed_only_string (line 236) | def last_boxed_only_string(string):
  function extract_answer (line 264) | def extract_answer(pred_str: str, execute: bool = False) -> str:
  function is_digit (line 310) | def is_digit(s):
  function math_equal (line 318) | def math_equal(
  function math_equal_process (line 411) | def math_equal_process(param):
  function math_equal_process_v2 (line 415) | def math_equal_process_v2(param):
  function symbolic_equal (line 421) | def symbolic_equal(a, b):
  function symbolic_equal_process (line 448) | def symbolic_equal_process(a, b, output_queue):
  function call_with_timeout (line 453) | def call_with_timeout(func, *args, timeout=1, **kwargs):
  function math_majority_vote (line 468) | def math_majority_vote(answers: list, majority: Optional[int] = None):

FILE: oreal/utils.py
  class ConfigDict (line 6) | class ConfigDict(dict):
    method __getattr__ (line 8) | def __getattr__(self, item):
    method __setattr__ (line 13) | def __setattr__(self, key, value):
  class Config (line 17) | class Config:
    method fromfile (line 20) | def fromfile(file_path):

FILE: train_oreal.py
  class RLParallelSampler (line 55) | class RLParallelSampler(ParallelSampler):
    method __iter__ (line 56) | def __iter__(self):
  function log_format (line 81) | def log_format(rank, debug=False):
  function is_interval (line 95) | def is_interval(step, total_steps, interval):
  function reduce_mean (line 99) | def reduce_mean(data, group):
  function threshold_rescale (line 105) | def threshold_rescale(prob, threshold=0.5):
  function topk_rescale (line 112) | def topk_rescale(prob, topk_ratio=0.5):
  function train_oreal (line 124) | def train_oreal(cfg_path, **kwargs):