Repository: jqtangust/Robust-R1
Branch: main
Commit: 02f2c9d3e785
Files: 45
Total size: 300.6 KB

Directory structure:
gitextract_8_p77n13/

├── .gitignore
├── README.md
├── add_degradation/
│   ├── add_degradation.py
│   └── generate_degradation.py
├── app.py
├── demo.py
├── requirements.txt
├── run_scripts/
│   └── run_grpo_robust.sh
├── setup.sh
└── src/
    ├── eval/
    │   ├── test_od_r1.py
    │   ├── test_rec_baseline.py
    │   ├── test_rec_r1.py
    │   └── test_rec_r1_internvl.py
    └── open-r1-multimodal/
        ├── .gitignore
        ├── LICENSE
        ├── Makefile
        ├── configs/
        │   ├── ddp.yaml
        │   ├── zero2.yaml
        │   └── zero3.yaml
        ├── local_scripts/
        │   ├── zero2.json
        │   ├── zero3.json
        │   ├── zero3.yaml
        │   ├── zero3_offload.json
        │   └── zero_stage2_config.json
        ├── setup.cfg
        ├── setup.py
        └── src/
            └── open_r1/
                ├── __init__.py
                ├── configs.py
                ├── evaluate.py
                ├── generate.py
                ├── grpo_jsonl.py
                ├── qwen2_5vl_monkey_patch.py
                ├── trainer/
                │   ├── __init__.py
                │   ├── grpo_config.py
                │   └── grpo_trainer.py
                ├── utils/
                │   ├── __init__.py
                │   ├── callbacks.py
                │   ├── evaluation.py
                │   ├── hub.py
                │   ├── math.py
                │   └── pycocotools/
                │       ├── coco.py
                │       └── cocoeval.py
                └── vlm_modules/
                    ├── __init__.py
                    ├── qwen_module.py
                    └── vlm_module.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile.lock

# poetry
poetry.lock

# pdm
.pdm.toml

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.conda/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
*.sublime-project
*.sublime-workspace

# Model checkpoints and outputs
checkpoints/
outputs/
runs/
*.pt
*.pth
*.ckpt
*.safetensors
*.bin
*.h5
*.hdf5
*.onnx
*.pb
*.tflite
*.pkl
*.pickle

# Training logs and outputs
logs/
wandb/
tensorboard/
tb_logs/
*.log
*.out
*.err

# Data files (uncomment if you don't want to track large data files)
# data/
# datasets/
# *.csv
# *.json
# *.jsonl
# *.parquet
# *.arrow

# Image files (uncomment if you don't want to track large image datasets)
# *.jpg
# *.jpeg
# *.gif
# *.bmp
# *.tiff
# *.webp
# Allow image files in the assets directory
!assets/**/*.png
!assets/**/*.jpg
!assets/**/*.jpeg

# Temporary files
*.tmp
*.temp
*.bak
*.swp
*.swo
*~

# OS files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
desktop.ini

# DeepSpeed
deepspeed_logs/
deepspeed_info/

# Hugging Face cache
.cache/
.huggingface/
transformers_cache/

# Local configuration files
config.local.yaml
config.local.yml
*.local.*

# Evaluation results
eval_results/
results/

# JSON files (but keep config files in specific directories)
*.json
!**/local_scripts/*.json
!**/config*.json
!package.json
!tsconfig.json
!requirements.txt

# Large files
*.zip
*.tar
*.tar.gz
*.rar
*.7z

# Node modules (if any)
node_modules/

# Compiled models
*.model
*.weights



================================================
FILE: README.md
================================================
<div align="center">

# [AAAI 2026 Oral] Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding
This is the official repository for Robust-R1.

[Jiaqi Tang^](https://jqt.me/), 
[Jianmin Chen^](https://github.com/Ch921-cell), 
\
[Wei Wei**](https://scholar.google.com/citations?hl=zh-CN&user=v8KMYlwAAAAJ), 
[Xiaogang Xu](https://xuxiaogang.com/), 
[Runtao Liu](https://scholar.google.com/citations?hl=zh-CN&user=YHTvXF4AAAAJ), 
[Xiangyu Wu](https://scholar.google.com/citations?user=R0GjVWIAAAAJ&hl=en), 
[Qipeng Xie](), 
[Jiafei Wu](), 
[Lei Zhang](https://scholar.google.com/citations?hl=zh-CN&user=0Kg6Gi4AAAAJ) and 
\
[Qifeng Chen*](https://cqf.io)

^: Equal contribution. *: Corresponding Author. **: Co-corresponding Author.

[![Paper](https://img.shields.io/badge/cs.CV-Paper-b31b1b?style=flat&logo=arxiv&logoColor=white)](https://huggingface.co/papers/2512.17532)
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-ffd21e)](https://huggingface.co/Jiaqi-hkust/Robust-R1)
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Data-ffd21e)](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1)
[![made-for-VSCode](https://img.shields.io/badge/Made%20for-VSCode-1f425f.svg)](https://code.visualstudio.com/)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

</div>

## 📰 **News**
- **[2025-12-23]** 🔥 Online demo is now available at [HF Space](https://huggingface.co/spaces/Jiaqi-hkust/Robust-R1).
- **[2025-12-23]** 🔥 We release the [Code](https://github.com/jqtangust/Robust-R1), along with the [Models](https://huggingface.co/Jiaqi-hkust/Robust-R1) and [Dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) on Hugging Face.
- **[2025-12-22]** ✅ Our paper is now available on [arXiv](https://arxiv.org/abs/2512.17532).
- **[2025-11-08]** 🚀 Our paper has been accepted to **AAAI 2026** as an **Oral** presentation.


---

## 🔭 **Motivation**

- 🚩 **Limited Interpretability**: Lack of explicit mechanisms to diagnose how degradation affects the original semantic information.
- 🚩 **Isolated Optimization**: Neglect of how degradation propagates between the visual encoder and the large language model.

<div align="center">
  <img src="assets/moti.png" width="85%" alt="Method Overview">
  <br>
</div>

---

## 🛠️ **Installation**

- **Clone the repository:**
   ```bash
   git clone https://github.com/jqtangust/Robust-R1.git
   cd Robust-R1
   ```

- **Create environment:**
   ```bash
   conda create -n robust_r1 python=3.10
   conda activate robust_r1
   bash setup.sh
   ```
---

### 🏰 **Pretrained and Fine-tuned Models**

- The following checkpoints are used to run Robust-R1:

  | Checkpoint | Link | Note |
  |:---------:|:----:|:----:|
  | Qwen2.5-VL-Base | [link](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) | Used as initial weights for training. |
  | **Robust-R1-SFT** | [link](https://huggingface.co/Jiaqi-hkust/Robust-R1-SFT) | Fine-tuned (SFT) on the [Robust-R1 dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1). |
  | **Robust-R1-RL** | [link](https://huggingface.co/Jiaqi-hkust/Robust-R1-RL) | Fine-tuned with reinforcement learning on the [Robust-R1 dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1). |

---

## ⏳ **Demo**

### 🖥️ CLI Demo

- Run the command-line demo with a question:

  ```bash
  # if you use local weights
  export MODEL_PATH="your_model_name_or_path"

  python demo.py "What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n"
  ```

### 🌐 GUI Demo

- Set the model path as an environment variable (`app.py` raises an error if `MODEL_PATH` is unset) and run the demo:

  ```bash
  # local path or Hugging Face model ID
  export MODEL_PATH="your_model_name_or_path"
  
  python app.py
  ```

- The demo will be available at `http://localhost:7860` by default.

- An [online GUI demo](https://huggingface.co/spaces/Jiaqi-hkust/Robust-R1) is also hosted on Hugging Face Spaces.

  <div align="center">
    <img src="assets/demo.png" alt="Robust-R1 Demo">
  </div>
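`app.py` asks the model to wrap each stage of its response in tags (`<TYPE>`, `<INFLUENCE>`, `<REASONING>`, `<CONCLUSION>`, `<ANSWER>`, each closed by a matching `*_END` tag). If you script the demo, a minimal sketch for pulling those sections out of a response could look like the following. The helper name `parse_robust_r1_output` is hypothetical, and it assumes the model follows the requested format; missing or unterminated sections are simply omitted rather than guessed.

```python
import re
from typing import Dict

# Section names from the system prompt in app.py; each section is expected
# as <NAME> ... <NAME_END>.
SECTIONS = ("TYPE", "INFLUENCE", "REASONING", "CONCLUSION", "ANSWER")


def parse_robust_r1_output(text: str) -> Dict[str, str]:
    """Map each tagged section found in `text` to its stripped content."""
    parsed: Dict[str, str] = {}
    for name in SECTIONS:
        # Non-greedy match; DOTALL so multi-line reasoning is captured too.
        match = re.search(rf"<{name}>(.*?)<{name}_END>", text, flags=re.DOTALL)
        if match:
            parsed[name] = match.group(1).strip()
    return parsed


# Usage on a made-up response string (not real model output):
response = "<TYPE>gaussian noise<TYPE_END> ... <ANSWER>B<ANSWER_END>"
print(parse_robust_r1_output(response))
# prints {'TYPE': 'gaussian noise', 'ANSWER': 'B'}
```

Keeping the parser tolerant of missing sections is deliberate: a degraded or truncated generation should still yield whatever sections are present instead of raising.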

---

## 🧠 **Training**

### 🎓 Supervised Fine-Tuning

We employ [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) for supervised fine-tuning of the base model.

1. Clone LLaMA-Factory and install its dependencies:

   ```bash
   git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
   cd LLaMA-Factory
   pip install -e ".[torch,metrics]"
   ```

2. Download the base model [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct).

3. Prepare the training data and configuration files:

   - Download the [Robust images](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) and unzip them.
   - Update the dataset configuration in the `LLaMA-Factory/data` directory (e.g., `dataset_info.json`).

4. Configure the training YAML file with your local paths (model path, dataset, and output directory).

5. Run the training command to train the SFT model:

   ```bash
   llamafactory-cli train examples/train_full/qwen2_5_vl_full_sft.yaml
   ```
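Step 3 leaves the LLaMA-Factory dataset configuration up to you. As one possible sketch, the snippet below registers a multimodal ShareGPT-style entry in `data/dataset_info.json`. The field layout is an assumption modeled on the `mllm_demo` entry bundled with LLaMA-Factory, and the dataset name and annotation file name used in the usage note (`robust_r1_sft`, `robust_r1_sft.json`) are hypothetical; check both against your LLaMA-Factory version and the files you downloaded.

```python
import json
from pathlib import Path


def register_dataset(llamafactory_dir: str, name: str, file_name: str) -> dict:
    """Add (or overwrite) a multimodal ShareGPT-style entry in data/dataset_info.json."""
    data_dir = Path(llamafactory_dir) / "data"
    if not data_dir.is_dir():
        raise FileNotFoundError(f"not a LLaMA-Factory checkout: {data_dir} is missing")
    info_path = data_dir / "dataset_info.json"
    # Keep any existing registrations (the bundled demo datasets, for example).
    info = json.loads(info_path.read_text(encoding="utf-8")) if info_path.exists() else {}
    # Assumed layout, mirroring the bundled "mllm_demo" entry.
    info[name] = {
        "file_name": file_name,
        "formatting": "sharegpt",
        "columns": {"messages": "messages", "images": "images"},
        "tags": {
            "role_tag": "role",
            "content_tag": "content",
            "user_tag": "user",
            "assistant_tag": "assistant",
        },
    }
    info_path.write_text(json.dumps(info, indent=2, ensure_ascii=False), encoding="utf-8")
    return info[name]
```

For example, `register_dataset("LLaMA-Factory", "robust_r1_sft", "robust_r1_sft.json")` would add the entry, after which the training YAML can refer to it with `dataset: robust_r1_sft`.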

### 🎓 Reinforcement Learning

1. Download the [Robust images](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) and unzip them into `Robust-R1/dataset`.

2. Prepare the training data file (train.jsonl) and organize the image folders.

3. Download the SFT model checkpoint from [Robust-R1-SFT](https://huggingface.co/Jiaqi-hkust/Robust-R1-SFT) or use your own trained SFT model.

4. Set the following variables in [run_scripts/run_grpo_robust.sh](run_scripts/run_grpo_robust.sh) to your own paths:

   ```bash
   data_paths="Robust-R1/data/train.jsonl" 
   image_folders="Robust-R1/data/train_images"
   model_path="your_model_name_or_path"
   ```

5. Run the script:

   ```bash
   bash run_scripts/run_grpo_robust.sh
   ```
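Before launching a long GRPO run, it can help to fail fast on bad paths. The following is a minimal pre-flight sketch that assumes only that `data_paths` points to a JSON Lines file (one JSON object per line) and `image_folders` to a directory. The helper names are hypothetical, and the check does not validate the record fields that `grpo_jsonl.py` expects, since that schema is not described here.

```python
import json
from pathlib import Path


def check_jsonl(path: str) -> int:
    """Return the number of records in a JSON Lines file.

    Blank lines are skipped; the first line that is not valid JSON raises
    ValueError with its line number.
    """
    count = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({err.msg})") from err
            count += 1
    return count


def preflight(data_path: str, image_folder: str) -> int:
    """Check the two paths set in run_grpo_robust.sh; return the record count."""
    if not Path(image_folder).is_dir():
        raise FileNotFoundError(f"image folder not found: {image_folder}")
    return check_jsonl(data_path)
```

Call it with the same values you put in the script, e.g. `preflight("Robust-R1/data/train.jsonl", "Robust-R1/data/train_images")`.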

---

## 📊 **Evaluation**

We use [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) to evaluate robustness under image degradation.

1. Clone the VLMEvalKit repository and install dependencies:

   ```bash
   git clone https://github.com/open-compass/VLMEvalKit.git
   cd VLMEvalKit
   pip install -e .
   ```

2. Prepare the evaluation datasets according to VLMEvalKit requirements.

3. **Image Degradation Pipeline**: Generate corrupted images for robustness evaluation.

   Navigate to the degradation pipeline directory and process the images:

   ```bash
   cd add_degradation
   python generate_degradation.py --input_dir <input_dir> --output_base_dir <output_base_dir> --dataset_name <dataset_name>
   ```

   For each image, the script randomly samples one degradation type and applies it at three intensities (0.9, 0.45, and 0.23), writing the results to `<output_base_dir>/<dataset_name>_Robust_100`, `<dataset_name>_Robust_50`, and `<dataset_name>_Robust_25`, respectively.

4. Configure the model path and evaluation settings in the VLMEvalKit configuration file.

5. Run the evaluation command:

   ```bash
   python run.py --model <your_model_name_or_path> --data <dataset_name>
   ```
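To make the sampling in step 3 concrete: `generate_degradation.py` flattens `DEGRADATION_CONFIG` into one weighted list, draws a single method per image, and applies that same method at every intensity. Because the weights in each of the four categories sum to 100, every category ends up with probability 0.25. The stdlib sketch below mirrors that logic for inspection only; it substitutes `random.choices` for the script's `np.random.choice`, and the helper names are hypothetical.

```python
import random

# Category -> {method: weight}, copied from DEGRADATION_CONFIG in
# add_degradation/generate_degradation.py.
DEGRADATION_CONFIG = {
    "capture": {"lens_blur": 20, "lens_flare": 20, "motion_blur": 20,
                "dirty_lens": 20, "hsv_saturation": 20},
    "transmission": {"jpeg_compression": 25, "block_exchange": 25,
                     "mean_shift": 25, "scan_lines": 25},
    "environment": {"dark_illumination": 25, "atmospheric_turbulence": 25,
                    "gaussian_noise": 25, "color_diffusion": 25},
    "postprocessing": {"sharpness_change": 33, "graffiti": 33,
                       "watermark_damage": 34},
}

# Intensity -> output folder suffix, as in generate_degradation.py.
INTENSITY_TO_SUFFIX = {0.9: "Robust_100", 0.45: "Robust_50", 0.23: "Robust_25"}


def method_probabilities(config):
    """Flatten the config and normalize the weights into per-method probabilities."""
    weights = {m: w for methods in config.values() for m, w in methods.items()}
    total = sum(weights.values())
    return {m: w / total for m, w in weights.items()}


def category_probabilities(config):
    """Sum the method probabilities within each category."""
    probs = method_probabilities(config)
    return {cat: sum(probs[m] for m in methods) for cat, methods in config.items()}


def plan_image(filename, dataset_name, rng):
    """Pick one method per image; list (relative output path, method, intensity)."""
    probs = method_probabilities(DEGRADATION_CONFIG)
    method = rng.choices(list(probs), weights=list(probs.values()))[0]
    # Paths are relative to --output_base_dir.
    return [(f"{dataset_name}_{suffix}/{filename}", method, intensity)
            for intensity, suffix in INTENSITY_TO_SUFFIX.items()]
```

For example, `plan_image("0001.jpg", "RealWorldQA", random.Random(0))` returns three jobs that share one method. Seeding a `random.Random` instance like this keeps a dry run reproducible, whereas the original script draws from NumPy's global RNG without a fixed seed.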

### 🔬 R-Bench Evaluation

We use [R-Bench](https://github.com/Q-Future/R-Bench) to assess model performance under real-world corruptions.

1. Clone the R-Bench repository:

   ```bash
   git clone https://github.com/Q-Future/R-Bench.git
   ```

2. Evaluate with VLMEvalKit on the R-Bench dataset:

   ```bash
   cd VLMEvalKit
   python run.py --data R-Bench-Dis --model <your_model_name_or_path> --verbose
   ```

3. For full dataset evaluation, follow the R-Bench pipeline as described in the [R-Bench repository](https://github.com/Q-Future/R-Bench).

---

## ⭐️ Citation

If you find Robust-R1 useful for your research and applications, please cite using this BibTeX:
   ```bibtex
   @inproceedings{tang2025robustr1,
     title={Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding},
     author={Tang, Jiaqi and Chen, Jianmin and Wei, Wei and Xu, Xiaogang and Liu, Runtao and Wu, Xiangyu and Xie, Qipeng and Wu, Jiafei and Zhang, Lei and Chen, Qifeng},
     booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
     year={2026}
   }
   ```

## 🤝 Acknowledgements
The work described in this paper was supported by a grant from the Research Grants Council of the Hong Kong Special Administrative Region, China (Project Reference Number: AoE/E-601/24-N).

We also thank the authors of [VLM-R1](https://github.com/om-ai-lab/VLM-R1?tab=readme-ov-file), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [R-Bench](https://github.com/Q-Future/R-Bench) for their contributions.



================================================
FILE: add_degradation/add_degradation.py
================================================
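"""Image degradation operators for building robustness benchmarks.

Each function takes a 3-channel BGR uint8 image (as returned by cv2.imread)
and an intensity, typically in [0.0, 1.0], and returns a degraded uint8 image.
"""
import cv2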
import numpy as np
import random


def motion_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    degree = max(5, int(intensity * 30))
    angle = random.uniform(0, 360)
    
    M = cv2.getRotationMatrix2D((degree/2, degree/2), angle, 1)
    kernel = np.diag(np.ones(degree))
    kernel = cv2.warpAffine(kernel, M, (degree, degree))
    kernel /= np.sum(kernel)
    
    return cv2.filter2D(img, -1, kernel)


def lens_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    kernel_size = int(3 + intensity * 300) | 1
    sigma = intensity * 20
    
    kernel = cv2.getGaussianKernel(kernel_size, sigma)
    kernel = kernel @ kernel.T
    
    blurred = np.zeros_like(img, dtype=np.float32)
    for c in range(3):
        blurred[..., c] = cv2.filter2D(img[..., c].astype(np.float32), -1, kernel)
    
    result = cv2.addWeighted(
        img, 1 - intensity * 0.7,
        blurred.astype(np.uint8), intensity * 0.9, 0
    )
    
    return result


def gaussian_noise(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    noise_std = intensity * 75
    noise = np.random.normal(0, noise_std, img.shape)
    result = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)
    
    return result


def block_exchange(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
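    h, w = img.shape[:2]
    block_size = min(32, int(5 + intensity * 30))
    noisy_img = img.copy()
    
    # Images smaller than one block have no blocks to swap
    # (random.randint(0, -1) would otherwise raise ValueError).
    if h < block_size or w < block_size:
        return noisy_img
    
    num_exchanges = int(intensity * 35)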
    for _ in range(num_exchanges):
        i1 = random.randint(0, h // block_size - 1)
        j1 = random.randint(0, w // block_size - 1)
        i2 = random.randint(0, h // block_size - 1)
        j2 = random.randint(0, w // block_size - 1)
        
        y1, x1 = i1 * block_size, j1 * block_size
        y2, x2 = i2 * block_size, j2 * block_size
        
        block1 = noisy_img[y1:y1+block_size, x1:x1+block_size].copy()
        noisy_img[y1:y1+block_size, x1:x1+block_size] = \
            noisy_img[y2:y2+block_size, x2:x2+block_size]
        noisy_img[y2:y2+block_size, x2:x2+block_size] = block1
    
    return noisy_img


def jpeg_compression(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if not 0 <= intensity <= 1:
        raise ValueError("Intensity must be in range [0.0, 1.0]")
    
    if img is None:
        raise ValueError("Input image is None")
    
    quality = int(100 - intensity * 95)
    quality = max(5, min(100, quality))
    
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    _, encimg = cv2.imencode('.jpg', img, encode_params)
    compressed_img = cv2.imdecode(encimg, cv2.IMREAD_COLOR)
    
    return compressed_img


def mean_shift(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    spatial_radius = int(intensity * 40)
    color_radius = int(intensity * 40)
    
    return cv2.pyrMeanShiftFiltering(img, spatial_radius, color_radius)


def color_diffusion(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    kernel_size = 3 + 2 * int(intensity * 20)
    sigma = intensity * 50
    
    kernel = cv2.getGaussianKernel(kernel_size, sigma)
    kernel = kernel @ kernel.T * (intensity ** 2)
    
    diffused = np.zeros_like(img, dtype=np.float32)
    for c in range(3):
        diffused[..., c] = cv2.filter2D(img[..., c].astype(np.float32), -1, kernel)
    
    if intensity > 0.9:
        h, w = img.shape[:2]
        for _ in range(int(100 * intensity)):
            x, y = np.random.randint(0, w), np.random.randint(0, h)
            radius = np.random.randint(5, 20)
            cv2.circle(diffused, (x, y), radius,
                      (np.random.randint(0, 255),) * 3, -1)
    
    result = cv2.addWeighted(
        img, max(0.1, 1 - intensity * 0.9),
        diffused.astype(np.uint8), min(0.9, intensity * 0.9), 0
    )
    
    return np.clip(result, 0, 255).astype(np.uint8)


def sharpness_change(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    if intensity > 0:
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) * (intensity * 80)
        result = cv2.filter2D(img, -1, kernel)
    else:
        ksize = int(3 + abs(intensity) * 5) | 1
        result = cv2.GaussianBlur(img, (ksize, ksize), 0)
    
    result = cv2.addWeighted(img, 0.7, result, 0.3, 0)
    return result


def dark_illumination(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    result = (img * (1 - intensity ** 2)).clip(0, 255).astype(np.uint8)
    return result


def hsv_saturation(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    hsv[..., 1] *= (1 - intensity)
    result = cv2.cvtColor(hsv.clip(0, 255).astype(np.uint8), cv2.COLOR_HSV2BGR)
    
    return result


def atmospheric_turbulence(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    h, w = img.shape[:2]
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    distortion = intensity * 40 * np.sin(y / 30 + intensity * 5)
    x_new = np.clip(x + distortion, 0, w - 1).astype(np.float32)
    y_new = np.clip(y + distortion * 0.7, 0, h - 1).astype(np.float32)
    
    return cv2.remap(img, x_new, y_new, cv2.INTER_LINEAR)


def dirty_lens(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    h, w = img.shape[:2]
    dirt = np.zeros((h, w, 3), dtype=np.float32)
    
    if intensity > 0.1:
        for _ in range(int(10 * intensity)):
            center_x = random.randint(0, w)
            center_y = random.randint(0, h)
            cv2.ellipse(dirt, (center_x, center_y),
                       (random.randint(150, 300), random.randint(100, 200)),
                       angle=random.randint(0, 180),
                       startAngle=0, endAngle=360,
                       color=(50, 50, 50), thickness=-1)
    
    for _ in range(int(300 * intensity)):
        x = random.randint(0, w)
        y = random.randint(0, h)
        cv2.circle(dirt, (x, y), random.randint(4, 20),
                  (random.randint(50, 100),) * 3, -1)
    
    if intensity > 0.5:
        for _ in range(int(5 * intensity)):
            x = random.randint(0, w)
            y = random.randint(0, h)
            cv2.circle(dirt, (x, y), random.randint(20, 50),
                      (80, 80, 80), -1)
            cv2.circle(dirt, (x, y), random.randint(10, 30),
                      (120, 120, 120), -1)
    
    dirt = cv2.GaussianBlur(dirt, (0, 0), 30)
    dirt = dirt.astype(np.uint8)
    
    result = cv2.addWeighted(img, 1 - 0.7 * intensity, dirt, 0.8 * intensity, 0)
    return np.clip(result, 0, 255).astype(np.uint8)


def scan_lines(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    line_interval = max(3, int(20 / (intensity + 0.1)))
    line_width = max(5, int(7 * intensity))
    
    result = img.copy()
    for i in range(0, img.shape[0], line_interval):
        end_line = min(i + line_width, img.shape[0])
        result[i:end_line] = result[i:end_line] * 0.01
    
    return result


def graffiti(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if not 0 <= intensity <= 1:
        raise ValueError("Intensity must be in range [0.0, 1.0]")
    
    if img is None:
        raise ValueError("Input image is None")
    
    h, w = img.shape[:2]
    result = img.copy()
    
    for _ in range(int(10 * intensity)):
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        pt1 = (random.randint(0, w - 1), random.randint(0, h - 1))
        pt2 = (random.randint(0, w - 1), random.randint(0, h - 1))
        thickness = random.randint(1, max(1, int(5 * intensity)))
        cv2.line(result, pt1, pt2, color, thickness)
    
    if intensity > 0.55:
        texts = ["X", "FAKE", "COPY", "VOID", "COPYRIGHT", str(random.randint(1, 100))]
        text = random.choice(texts)
        
        font_scale = max(0.5, intensity * 5)
        thickness = max(1, int(font_scale))
        
        text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
        text_width, text_height = text_size
        
        if h - 10 > text_height + 10:
            text_x = random.randint(0, max(1, w - text_width - 10))
            text_y = random.randint(text_height + 10, h - 10)
            
            cv2.putText(result, text,
                       (text_x, text_y),
                       cv2.FONT_HERSHEY_SIMPLEX,
                       font_scale,
                       (0, 0, 255),
                       thickness)
    
    return result


def watermark_damage(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    h, w = img.shape[:2]
    mask = np.zeros((h, w), dtype=np.float32)
    
    for _ in range(int(1 + intensity * 15)):
        # Clamp so images smaller than 50 px don't make randint fail.
        x = random.randint(0, max(0, w - 50))
        y = random.randint(0, max(0, h - 50))
        cv2.rectangle(mask, (x, y),
                     (x + random.randint(50, 200), y + random.randint(20, 80)), 1, -1)
    
    repaired = cv2.inpaint(img, (mask * 255).astype(np.uint8), 3, cv2.INPAINT_TELEA)
    result = cv2.addWeighted(img, 1 - intensity, repaired, intensity, 0)
    
    if intensity > 0.5:
        edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150)
        result[edges > 0] = result[edges > 0] * 0.8
    
    return result


def lens_flare(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    
    h, w = img.shape[:2]
    flare = np.zeros((h, w, 3), dtype=np.float32)
    
    num_flares = 3 + int(30 * intensity)
    for _ in range(num_flares):
        x = random.randint(0, w)
        y = random.randint(0, h)
        radius = random.randint(10, 50)
        color = np.array([255, 255, 235])
        
        cv2.circle(flare, (x, y), radius, color.tolist(), -1)
        
        angle = random.uniform(0, 2 * np.pi)
        length = random.randint(30, 150)
        end_x = int(x + length * np.cos(angle))
        end_y = int(y + length * np.sin(angle))
        cv2.line(flare, (x, y), (end_x, end_y), color.tolist(), 2)
    
    flare = cv2.GaussianBlur(flare, (3, 3), 20 * intensity)
    
    result = cv2.addWeighted(img.astype(np.float32), 1, flare, 0.9 * intensity, 0)
    return np.clip(result, 0, 255).astype(np.uint8)


================================================
FILE: add_degradation/generate_degradation.py
================================================
import add_degradation
import cv2
import os
import numpy as np
import argparse

DEGRADATION_CONFIG = {
    'capture': {
        'lens_blur': {'weight': 20},
        'lens_flare': {'weight': 20},
        'motion_blur': {'weight': 20},
        'dirty_lens': {'weight': 20},
        'hsv_saturation': {'weight': 20}
    },
    'transmission': {
        'jpeg_compression': {'weight': 25},
        'block_exchange': {'weight': 25},
        'mean_shift': {'weight': 25},
        'scan_lines': {'weight': 25}
    },
    'environment': {
        'dark_illumination': {'weight': 25},
        'atmospheric_turbulence': {'weight': 25},
        'gaussian_noise': {'weight': 25},
        'color_diffusion': {'weight': 25}
    },
    'postprocessing': {
        'sharpness_change': {'weight': 33},
        'graffiti': {'weight': 33},
        'watermark_damage': {'weight': 34}
    }
}

def apply_degradation_Benchmark(image, method_name, intensity):
    degradation_func = getattr(add_degradation, method_name)
    degraded_img = degradation_func(image, intensity)
    return degraded_img

def main():
    parser = argparse.ArgumentParser(description='Image degradation pipeline for robustness evaluation')
    parser.add_argument('--input_dir', type=str, 
                       default=os.getenv('INPUT_DIR', './data/images'),
                       help='Input image directory path (can be set via INPUT_DIR environment variable)')
    parser.add_argument('--output_base_dir', type=str,
                       default=os.getenv('OUTPUT_BASE_DIR', './data/output'),
                       help='Base directory for output images (can be set via OUTPUT_BASE_DIR environment variable)')
    parser.add_argument('--dataset_name', type=str,
                       default=os.getenv('DATASET_NAME', 'RealWorldQA'),
                       help='Dataset name (used to generate output directory names)')
    
    args = parser.parse_args()
    
    folder_path = args.input_dir
    output_base_dir = args.output_base_dir
    dataset_name = args.dataset_name
    
    # Degradation intensity -> output directory (one folder per severity level).
    output_dirs = {
        0.9: os.path.join(output_base_dir, f'{dataset_name}_Robust_100'),
        0.45: os.path.join(output_base_dir, f'{dataset_name}_Robust_50'),
        0.23: os.path.join(output_base_dir, f'{dataset_name}_Robust_25')
    }

    if not os.path.exists(folder_path):
        raise ValueError(f"Input directory does not exist: {folder_path}")
    
    for path in output_dirs.values():
        os.makedirs(path, exist_ok=True)

    all_methods_with_weights = []
    for category, methods in DEGRADATION_CONFIG.items():
        for method_name, details in methods.items():
            all_methods_with_weights.append((method_name, details['weight']))

    method_names = [item[0] for item in all_methods_with_weights]
    weights = [item[1] for item in all_methods_with_weights]

    total_weight = sum(weights)
    probabilities = [w / total_weight for w in weights]

    num = 0
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
            image_path = os.path.join(folder_path, filename)
            image = cv2.imread(image_path)
            
            if image is None:
                print(f"Warning: Could not read image {image_path}, skipping")
                num += 1
                continue

            selected_method_name = np.random.choice(method_names, p=probabilities)

            for intensity, output_dir in output_dirs.items():
                degraded_img = apply_degradation_Benchmark(image, selected_method_name, intensity)
                save_path = os.path.join(output_dir, filename)
                cv2.imwrite(save_path, degraded_img)
            
            num += 1
            if num % 100 == 0:
                print(f"Processed {num} images")
        
    print("Processing completed!")

if __name__ == '__main__':
    main()


================================================
FILE: app.py
================================================
import gradio as gr
import os
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import html

sys_prompt = """First output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags, 
        and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, 
        then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags, 
        and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
        provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""

project_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = os.path.join(project_dir, ".gradio_temp")
os.makedirs(temp_dir, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = temp_dir

MODEL_PATH = os.getenv("MODEL_PATH", "")

if not MODEL_PATH:
    raise ValueError("MODEL_PATH environment variable must be set. Please set it to your model path.")

print("==========================================")
print("Initializing application...")
print("==========================================")

class ModelHandler:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self._load_model()

    def _load_model(self):
        try:
            print(f"⏳ Loading model weights, this may take a few minutes...")
            
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation="flash_attention_2" if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else "eager"
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            raise

    def predict(self, message_dict, history, temperature, max_tokens):
        text = message_dict.get("text", "")
        files = message_dict.get("files", [])

        messages = []
        
        if history:
            print(f"Processing {len(history)} previous messages from history")
            for msg in history:
                role = msg.get("role", "")
                content = msg.get("content", "")
                
                if role == "user":
                    user_content = []
                    
                    if isinstance(content, list):
                        for item in content:
                            if isinstance(item, str):
                                if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
                                    user_content.append({"type": "image", "image": item})
                                else:
                                    user_content.append({"type": "text", "text": item})
                            elif isinstance(item, dict):
                                user_content.append(item)
                    elif isinstance(content, str):
                        if content:
                            user_content.append({"type": "text", "text": content})
                    
                    if user_content:
                        messages.append({"role": "user", "content": user_content})
                        
                elif role == "assistant":
                    if isinstance(content, str) and content:
                        messages.append({"role": "assistant", "content": content})
        
        current_content = []
        if files:
            for file_path in files:
                current_content.append({"type": "image", "image": file_path})
        
        if text:
            sys_prompt_formatted = " ".join(sys_prompt.split())
            full_text = f"{text}\n{sys_prompt_formatted}"
            current_content.append({"type": "text", "text": full_text})
        
        if current_content:
            messages.append({"role": "user", "content": current_content})
        
        print(f"Total messages for model: {len(messages)}")
        print(f"Message roles: {[m['role'] for m in messages]}")

        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        
        image_inputs, video_inputs = process_vision_info(messages)
        
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        
        inputs = inputs.to(self.model.device)

        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

        try:
            print("Starting model generation...")
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)
            
            input_length = inputs['input_ids'].shape[1]
            generated_ids = generated_ids[0][input_length:]
            
            print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}")
            
            generated_text = self.processor.tokenizer.decode(
                generated_ids, 
                skip_special_tokens=True
            )
            
            print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}")
            
            if generated_text and generated_text.strip():
                print(f"Yielding generated text: {generated_text[:100]}...")
                yield generated_text
            else:
                warning_msg = "⚠️ No output generated. The model may not have produced any response."
                print(warning_msg)
                yield warning_msg
                
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"Error in model.generate: {error_details}")
            yield f"❌ Generation error: {str(e)}"
            return

model_handler = ModelHandler(MODEL_PATH)

def create_chat_ui():
    custom_css = """
    .gradio-container { font-family: 'Inter', sans-serif; }
    #chatbot { height: 650px !important; overflow-y: auto; }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Robust-R1") as demo:
        
        with gr.Row():
            gr.Markdown("# 🤖Robust-R1:Degradation-Aware Reasoning for Robust Visual Understanding")

        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Chat",
                    type="messages",
                    avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
                    height=650
                )
                
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_types=["image"],
                    placeholder="Enter your question or upload an image...",
                    show_label=False
                )

            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### ⚙️ Generation Config")
                    temperature = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.6, step=0.05, 
                        label="Temperature"
                    )
                    max_tokens = gr.Slider(
                        minimum=128, maximum=4096, value=1024, step=128, 
                        label="Max New Tokens"
                    )
                
                clear_btn = gr.Button("🗑️ Clear Context", variant="stop")

        gr.Markdown("---")
        gr.Markdown("### 📚 Examples")
        gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
        
        example_images_dir = os.path.join(project_dir, "assets")
        
        examples_config = [
            ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
            ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
        ]
        
        example_data = []
        for text, img_path in examples_config:
            if os.path.exists(img_path):
                example_data.append({"text": text, "files": [img_path]})
        
        if example_data:
            gr.Examples(
                examples=example_data,
                inputs=chat_input,
                label="",
                examples_per_page=3
            )
        else:
            gr.Markdown("*No example images available, please manually upload images for testing*")
        
        def respond(user_msg, history, temp, tokens):
            # Synchronous generator: Gradio iterates it in a worker thread, so the
            # blocking model.generate() call does not stall the server's event loop.
            text = user_msg.get("text", "").strip()
            files = user_msg.get("files", [])
            user_content = list(files)
            if text:
                user_content.append(text)

            if not files and text:
                user_message = {"role": "user", "content": text}
            else:
                user_message = {"role": "user", "content": user_content}
            
            history.append(user_message)
            yield history, gr.MultimodalTextbox(value=None, interactive=False)

            history.append({"role": "assistant", "content": ""})
            
            try:
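                # Exclude the just-added user turn and the empty assistant placeholder;
                # predict() rebuilds the current turn from user_msg itself.
                previous_history = history[:-2]

                generated_text = ""
                for chunk in model_handler.predict(user_msg, previous_history, temp, tokens):
                    generated_text = chunk

                    # Escape &, < and > so tags such as <TYPE> or <ANSWER> are shown
                    # literally instead of being interpreted as HTML by the chatbot.
                    safe_text = html.escape(generated_text, quote=False)

                    history[-1]["content"] = safe_text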
                    yield history, gr.MultimodalTextbox(interactive=False)
                    
            except Exception as e:
                import traceback
                traceback.print_exc()
                history[-1]["content"] = f"❌ Inference error: {str(e)}"
                yield history, gr.MultimodalTextbox(interactive=True)
            
            yield history, gr.MultimodalTextbox(value=None, interactive=True)
            
        chat_input.submit(
            respond,
            inputs=[chat_input, chatbot, temperature, max_tokens],
            outputs=[chatbot, chat_input]
        )

        def clear_history():
            return [], None

        clear_btn.click(clear_history, outputs=[chatbot, chat_input])

    return demo

if __name__ == "__main__":
    demo = create_chat_ui()
    
    print("🚀 Service is starting, please visit: http://localhost:7862")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7862,
        share=False,
        show_error=True,
        allowed_paths=[project_dir]
    )


================================================
FILE: demo.py
================================================
#!/usr/bin/env python3
"""
CLI Demo for Robust-R1: Visual Question Answering with Degradation-Aware Reasoning.
"""

import os
import sys
import torch
import argparse
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Default model path - can be overridden by MODEL_PATH environment variable
# Users can set MODEL_PATH to their local model path or HuggingFace model name
DEFAULT_MODEL_PATH = "Jiaqi-hkust/Robust-R1-RL"  # HuggingFace model name
MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_PATH)

# Fixed image path for demo
FIXED_IMAGE_PATH = "assets/1.jpg"

SYS_PROMPT = """First output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags, 
and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, 
then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags, 
and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""

DEFAULT_TEMPERATURE = 0.6
DEFAULT_MAX_TOKENS = 1024


class ModelHandler:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self._load_model()

    def _load_model(self):
        try:
            print("Loading model, this may take a few minutes...")
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation="flash_attention_2" if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else "eager"
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Model loading failed: {e}")
            raise

    def predict(self, question, image_path, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS):
        """
        Generate response for the given question and image.
        
        Args:
            question: User question
            image_path: Path to the image
            temperature: Generation temperature
            max_tokens: Maximum number of tokens to generate
        
        Returns:
            Generated text response
        """
        sys_prompt_formatted = " ".join(SYS_PROMPT.split())
        full_text = f"{question}\n{sys_prompt_formatted}"
        
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": full_text},
                    {"type": "image", "image": image_path},
                ],
            }
        ]
        
        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        
        image_inputs, video_inputs = process_vision_info(messages)
        
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        
        inputs = inputs.to(self.model.device)
        
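        # temperature is only meaningful when sampling, so --temperature 0 falls
        # back to greedy decoding and is not forwarded to generate().
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=temperature > 0,
        )
        if temperature > 0:
            generation_kwargs["temperature"] = temperature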
        
        try:
            print("Generating response...")
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)
            
            input_length = inputs['input_ids'].shape[1]
            generated_ids = generated_ids[0][input_length:]
            
            generated_text = self.processor.tokenizer.decode(
                generated_ids, 
                skip_special_tokens=True
            )
            
            return generated_text
                
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"Generation error: {error_details}")
            raise
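

# Sketch (not part of the original script): SYS_PROMPT asks the model to wrap each
# stage of its reply in <TAG> ... <TAG_END> markers (TYPE, INFLUENCE, REASONING,
# CONCLUSION, ANSWER). This hypothetical helper pulls one stage out of the text
# returned by ModelHandler.predict; it assumes the model follows that format and
# returns None when the tag pair is missing.
import re


def extract_tagged_section(text, tag):
    """Return the stripped text between <tag> and <tag_END>, or None if absent."""
    pattern = rf"<{re.escape(tag)}>(.*?)<{re.escape(tag)}_END>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else None

# Example: extract_tagged_section(response, "ANSWER") keeps only the short final
# answer and drops the degradation analysis and reasoning stages.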


def main():
    parser = argparse.ArgumentParser(
        description="CLI Demo for Robust-R1: Visual Question Answering with Degradation-Aware Reasoning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python demo.py "What type of vehicles are the people riding?"
    python demo.py "What is in the image?" --temperature 0.7 --max-tokens 2048
    python demo.py "Your question" --image /path/to/image.jpg
        """
    )
    
    parser.add_argument(
        "question",
        type=str,
        help="Question to ask about the image"
    )
    
    parser.add_argument(
        "--image", "-i",
        type=str,
        default=FIXED_IMAGE_PATH,
        help=f"Path to the input image (default: {FIXED_IMAGE_PATH})"
    )
    
    parser.add_argument(
        "--temperature", "-t",
        type=float,
        default=DEFAULT_TEMPERATURE,
        help=f"Generation temperature (default: {DEFAULT_TEMPERATURE})"
    )
    
    parser.add_argument(
        "--max-tokens", "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Maximum number of tokens to generate (default: {DEFAULT_MAX_TOKENS})"
    )
    
    parser.add_argument(
        "--model-path",
        type=str,
        default=MODEL_PATH,
        help=f"Model path or HuggingFace model name (default: {MODEL_PATH}). Can also be set via MODEL_PATH environment variable."
    )
    
    args = parser.parse_args()
    
    if not os.path.exists(args.image):
        print(f"Error: Image file does not exist: {args.image}")
        sys.exit(1)
    
    print(f"Model path: {args.model_path}")
    print(f"Image path: {args.image}")
    print(f"Question: {args.question}")
    print(f"Temperature: {args.temperature}, Max tokens: {args.max_tokens}")
    print("-" * 80)
    
    model_handler = ModelHandler(args.model_path)
    
    try:
        response = model_handler.predict(
            question=args.question,
            image_path=args.image,
            temperature=args.temperature,
            max_tokens=args.max_tokens
        )
        
        print("\n" + "=" * 80)
        print("Model Response:")
        print("=" * 80)
        print(response)
        print("=" * 80)
        
    except KeyboardInterrupt:
        print("\n\nUser interrupted")
        sys.exit(0)
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()


================================================
FILE: requirements.txt
================================================
torch>=2.5.1
transformers==4.49.0
gradio>=4.0.0
qwen-vl-utils>=0.0.1
accelerate>=1.2.1
sentencepiece>=0.1.99
pillow
safetensors>=0.3.3
huggingface-hub>=0.19.2,<1.0
einops>=0.8.0
packaging>=23.0
numpy>=1.21.0


================================================
FILE: run_scripts/run_grpo_robust.sh
================================================
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )"
export REPO_HOME="${PROJECT_ROOT}"
echo "REPO_HOME: $REPO_HOME"
# Change the data_paths and image_folders to your own data
data_paths="your_data_path" 
image_folders="your_images_folder"
model_path="your_model_name_or_path"
is_reward_customized_from_vlm_module=True
echo "data_paths: $data_paths"
echo "image_folders: $image_folders"


export EXP_NAME="your_experiment_name" # TODO: change this to your own experiment name
TASK_TYPE="robust"
cd ${REPO_HOME}/src/open-r1-multimodal

export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
# create the run directory and log file
mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log
export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt"
# MAX_STEPS=1200 # TODO: change this to your own max steps

# export WANDB_DISABLED=true     
torchrun --nproc_per_node="8" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="12352" \
  src/open_r1/grpo_jsonl.py \
    --use_vllm False \
    --output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \
    --resume_from_checkpoint True \
    --model_name_or_path $model_path \
    --data_file_paths $data_paths \
    --image_folders $image_folders \
    --is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \
    --task_type $TASK_TYPE \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --gradient_checkpointing true \
    --logging_steps 1 \
    --num_train_epochs 1 \
    --bf16 \
    --attn_implementation flash_attention_2 \
    --run_name ${EXP_NAME} \
    --data_seed 42 \
    --save_steps 100 \
    --num_generations 8 \
    --max_completion_length 2048 \
    --reward_funcs accuracy format type length \
    --beta 0.04 \
    --report_to none \
    --dataset-name this_is_not_used \
    --deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \
    --freeze_vision_modules true
    

echo "Training completed for ${EXP_NAME}"


================================================
FILE: setup.sh
================================================
# conda create -n vlm-r1 python=3.11 
# conda activate vlm-r1

# Install the packages in open-r1-multimodal .
cd src/open-r1-multimodal # We edit the grpo.py and grpo_trainer.py in open-r1 repo.

# Install torch first (required for flash-attn)
# Quote the version specifier: unquoted, the shell parses ">=2.5.1" as an output redirection
pip install "torch>=2.5.1" torchvision

# Install open-r1 package with dev dependencies
pip install -e ".[dev]"

# Additional modules
pip install wandb==0.18.3
pip install tensorboardx
pip install qwen_vl_utils
pip install babel
pip install python-Levenshtein
pip install matplotlib
pip install pycocotools
pip install openai
pip install "httpx[socks]"

# Install flash-attn last (requires torch to be already installed)
pip install flash-attn --no-build-isolation

================================================
FILE: src/eval/test_od_r1.py
================================================
import re
import os
import json
import torch
import random

from tqdm import tqdm
from pprint import pprint
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor


def extract_bbox_answer(content):
    # Expect a json-fenced block holding a list whose first item has a "bbox_2d" key.
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, content, re.DOTALL)
    bbox_json = json_match.group(1).strip() if json_match else None

    if bbox_json:
        try:
            bbox = json.loads(bbox_json)[0]['bbox_2d']
            return bbox, False
        except (json.JSONDecodeError, IndexError, KeyError, TypeError):
            return [0, 0, 0, 0], False
    else:
        return [0, 0, 0, 0], False


def iou(box1, box2):
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2] - 1, box2[2] - 1)
    inter_y2 = min(box1[3] - 1, box2[3] - 1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
    else:
        inter = 0
    union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
    return float(inter) / union
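

# Sketch (not used by eval_od_r1): iou() above computes the intersection with
# inclusive pixel indices (the -1/+1 terms) but the box areas without them, and it
# raises ZeroDivisionError when both boxes are degenerate (e.g. two [0, 0, 0, 0]).
# A plain continuous-coordinate IoU for [x1, y1, x2, y2] boxes with a zero-union
# guard is shown here. The name iou_continuous is hypothetical, and switching the
# metric would change the reported accuracy numbers.
def iou_continuous(box1, box2):
    inter_w = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0.0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = inter_w * inter_h
    area1 = max(0.0, box1[2] - box1[0]) * max(0.0, box1[3] - box1[1])
    area2 = max(0.0, box2[2] - box2[0]) * max(0.0, box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0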


def load_model(model_path, device_map):
    # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device_map,
    )

    # default processor
    processor = AutoProcessor.from_pretrained(model_path)

    return model, processor


def eval_od_r1(
    model_path, test_datasets, data_root, image_root, question_template, output_dir, batch_size=32, sample_num=500, seed=42, device_map="cuda:0"
):
    random.seed(seed)
    model, processor = load_model(model_path, device_map)

    for ds in test_datasets:
        print(f"Processing {ds}...")

        ds_path = os.path.join(data_root, f"{ds}.json")
        with open(ds_path, "r") as f:
            data = json.load(f)
        random.shuffle(data)
        data = data[:sample_num]
        messages = []

        for x in data:
            image_path = os.path.join(image_root, x['image'])
            messages.append(
                [
                    {
                        "role":
                            "user",
                        "content":
                            [
                                {
                                    "type": "image",
                                    "image": f"file://{image_path}"
                                }, {
                                    "type": "text",
                                    "text": question_template.format(Question=x['normal_caption'])
                                }
                            ]
                    }
                ]
            )

        all_outputs = []  # List to store all answers

        # Process data
        for i in tqdm(range(0, len(messages), batch_size)):
            batch_messages = messages[i:i + batch_size]

            # Preparation for inference
            text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]

            image_inputs, video_inputs = process_vision_info(batch_messages)
            inputs = processor(
                text=text,
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device_map)

            # Inference: Generation of the output
            generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)

            generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
            batch_output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            all_outputs.extend(batch_output_text)

        final_output = []
        correct_number = 0

        for input_example, model_output in zip(data, all_outputs):
            original_output = model_output
            ground_truth = input_example['solution']
            ground_truth_normalized = input_example['normalized_solution']
            model_answer, normalized = extract_bbox_answer(original_output)

            # Count correct answers (iou_value defaults to 0.0 so "iou" in the result is always defined)
            correct = 0
            iou_value = 0.0
            if model_answer is not None:
                iou_value = iou(model_answer, ground_truth_normalized if normalized else ground_truth)
                if iou_value > 0.5:
                    correct = 1
            correct_number += correct

            # Create a result dictionary for this example
            result = {
                "question": question_template.format(Question=input_example['normal_caption']),
                "ground_truth": ground_truth if not normalized else ground_truth_normalized,
                "model_output": original_output,
                "extracted_answer": model_answer,
                "correct": correct,
                "iou": iou_value
            }
            final_output.append(result)

        # Calculate and print accuracy
        accuracy = correct_number / len(data) * 100
        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")

        # Save results to a JSON file
        result_path = os.path.join(output_dir, f"{os.path.basename(model_path)}", f"{ds}_od_r1.json")
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        with open(result_path, "w") as f:
            json.dump({"accuracy": accuracy, "results": final_output}, f, indent=2)

        print(f"Results saved to {result_path}")
        print('-' * 100)


if __name__ == "__main__":
    model_path = ''  # Add the path to the model
    data_root = ''  # Add the data root
    test_datasets = ['refcoco_val', 'refcocop_val', 'refcocog_val']  # modify the datasets
    image_root = ''  # Add the image root
    output_dir = 'logs'  # Add the output directory, default is logs
    device_map = 'cuda:0'  # select the device, default is cuda:0

    question_template = '{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format.'  # modify the question template which must contain {Question}, {Question} will be replaced by the caption

    eval_od_r1(
        model_path=model_path,
        data_root=data_root,
        test_datasets=test_datasets,
        image_root=image_root,
        question_template=question_template,
        output_dir=output_dir,
        device_map=device_map
    )


================================================
FILE: src/eval/test_rec_baseline.py
================================================
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

def setup_distributed():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank) 
    
    dist.init_process_group(backend="nccl")
    
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    
    print(f"Process {rank}/{world_size} initialized on cuda:{local_rank}")
    return local_rank, world_size, rank

local_rank, world_size, rank = setup_distributed()
device = f"cuda:{local_rank}"

steps = 100
MODEL_PATH=f"/data10/shz/project/LLaMA-Factory/saves/qwen2_5_vl-3b/full/sft/checkpoint-{steps}" 
OUTPUT_PATH="./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_sft_{STEPS}.json"

# MODEL_PATH = "/data10/shz/ckpt/vlm-r1-related/Qwen2.5-VL-3B-Instruct"
# OUTPUT_PATH = "./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_baseline_{STEPS}.json"

BSZ=4
DATA_ROOT = "/data10/shz/dataset/rec/rec_jsons_processed"

TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
IMAGE_ROOT = "/data10/shz/dataset/coco"

# TEST_DATASETS = ['lisa_test']
# IMAGE_ROOT = "/data10/shz/dataset/lisa"

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": local_rank},
)

# default processor
processor = AutoProcessor.from_pretrained(MODEL_PATH)

def extract_bbox_answer(content):
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    # bbox_pattern = r'\[(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+)\]'
    bbox_match = re.search(bbox_pattern, content)

    if bbox_match:
        bbox = [float(bbox_match.group(1)), float(bbox_match.group(2)), float(bbox_match.group(3)), float(bbox_match.group(4))]
        return bbox
    return [0, 0, 0, 0]

def iou(box1, box2):
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2]-1, box2[2]-1)
    inter_y2 = min(box1[3]-1, box2[3]-1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
    else:
        inter = 0
    union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
    return float(inter)/union
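
# Sketch (hypothetical helper, not called by this script): the per-rank split in
# the loop below gives each rank len(data) // world_size contiguous examples and
# lets the last rank absorb the remainder. Factored out, the same arithmetic is:
def shard_range(num_items, world_size, rank):
    """Return the (start, end) indices of the contiguous slice assigned to `rank`."""
    per_rank = num_items // world_size
    start = rank * per_rank
    end = start + per_rank if rank < world_size - 1 else num_items
    return start, end

# e.g. 10 examples on 3 ranks are split as (0, 3), (3, 6), (6, 10).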

num_samples = 2000
for ds in TEST_DATASETS:
    if rank == 0:
        print(f"Processing {ds}...")
    ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
    with open(ds_path, "r") as f:
        data = json.load(f)
    random.seed(42)
    random.shuffle(data)
    data = data[:num_samples]
    # QUESTION_TEMPLATE = "{Question}" if steps > 0 else "{Question} Please provide the bounding box coordinate in JSON format."
    QUESTION_TEMPLATE = "{Question} Please provide the bounding box coordinate in JSON format."
    
    # Split data for distributed evaluation
    per_rank_data = len(data) // world_size
    start_idx = rank * per_rank_data
    end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
    rank_data = data[start_idx:end_idx]
    
    messages = []

    for x in rank_data:
        image_path = os.path.join(IMAGE_ROOT, x['image'])
        message = [
            # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
            "role": "user",
            "content": [
                {
                    "type": "image", 
                    "image": f"file://{image_path}"
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=x['problem'])
                }
            ]
        }]
        messages.append(message)

    rank_outputs = [] # List to store answers for this rank
    all_outputs = []  # List to store all answers

    # Process data
    for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
        batch_messages = messages[i:i + BSZ]
    
        # Preparation for inference
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Inference: Generation of the output
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
        
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        
        rank_outputs.extend(batch_output_text)

    print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")

    # Gather all outputs from all ranks
    all_outputs = [None] * len(data)
    rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]

    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, rank_results)
    
    assert gathered_results[-1][-1][0] == len(data) - 1

    # The main process will collect all results
    if rank == 0:
        for results in gathered_results:
            for idx, output in results:
                assert idx < len(all_outputs)
                all_outputs[idx] = output
        assert all_outputs[-1] is not None

        final_output = []
        correct_number = 0

        for input_example, model_output in zip(data, all_outputs):
            original_output = model_output
            ground_truth = input_example['solution']
            model_answer = extract_bbox_answer(original_output)
            
            # Count correct answers
            correct = 0
            if model_answer is not None:
                if iou(model_answer, ground_truth) > 0.5:
                    correct = 1
            correct_number += correct
            
            # Create a result dictionary for this example
            result = {
                'image': input_example['image'],
                'question': input_example['problem'],
                'ground_truth': ground_truth,
                'model_output': original_output,
                'extracted_answer': model_answer,
                'correct': correct
            }
            final_output.append(result)

        # Calculate and print accuracy
        accuracy = correct_number / len(data) * 100
        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")

        # Save results to a JSON file
        output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(output_path, "w") as f:
            json.dump({
                'accuracy': accuracy,
                'results': final_output
            }, f, indent=2)

        print(f"Results saved to {output_path}")
        print("-"*100)
    
    # Synchronize all processes
    dist.barrier()







================================================
FILE: src/eval/test_rec_r1.py
================================================
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

def setup_distributed():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank) 
    
    dist.init_process_group(backend="nccl")
    
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    
    return local_rank, world_size, rank
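
# Illustrative sketch (hypothetical helpers, not called anywhere in this
# script): the evaluation loop below splits the shuffled dataset into
# contiguous per-rank slices, with the last rank also taking the remainder,
# and rank 0 later writes each gathered (global_index, output) pair back into
# dataset order. These two functions restate that logic in isolation.
def shard_bounds(num_items, world_size, rank):
    """Return the [start, end) slice of the dataset evaluated by `rank`."""
    per_rank = num_items // world_size
    start = rank * per_rank
    # The last rank absorbs the remainder so every item is evaluated.
    end = start + per_rank if rank < world_size - 1 else num_items
    return start, end


def merge_rank_results(gathered, num_items):
    """Reassemble per-rank lists of (global_index, output) in dataset order."""
    merged = [None] * num_items
    for rank_results in gathered:
        for idx, output in rank_results:
            merged[idx] = output
    # Every slot should have been filled by some rank.
    assert all(out is not None for out in merged)
    return merged

# Example: 10 items on 3 ranks -> (0, 3), (3, 6), (6, 10).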

local_rank, world_size, rank = setup_distributed()
device = f"cuda:{local_rank}"
print(f"Process {rank} using {device}")

main_rank = 0
steps = 100
if rank == main_rank:
    print("Steps: ", steps)

RUN_NAME = "Qwen2.5-VL-3B-Instruct-rec"

MODEL_PATH=f"/training/shz/project/vlm-r1/VLM-R1/checkpoints/rl/{RUN_NAME}/checkpoint-{steps}"
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"

BSZ=2   
DATA_ROOT = "/training/shz/dataset/vlm-r1/rec_jsons_processed"

# TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
# IMAGE_ROOT = "/training/shz/dataset/coco"


TEST_DATASETS = ['lisa_test']
IMAGE_ROOT = "/training/shz/dataset/lisa"


# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": local_rank}, 
)

# Default processor
processor = AutoProcessor.from_pretrained(MODEL_PATH)

def extract_bbox_answer(content):
    # Find the bbox inside <answer> tags; return [0, 0, 0, 0] if none is found.
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
    content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if content_answer_match:
        content_answer = content_answer_match.group(1).strip()
        bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
        if bbox_match:
            bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
            return bbox
    return [0, 0, 0, 0]
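
# Illustrative sketch (hypothetical helpers, not called anywhere in this
# script): a stricter variant of `extract_bbox_answer` that parses the
# <answer> payload with json.loads and returns the first list of four integers
# found anywhere in it. The exact JSON layout the model emits is an assumption,
# so the search walks dicts and lists generically. Unlike the regex above, an
# answer that is not valid JSON (for example one wrapped in Markdown code
# fences) falls back to [0, 0, 0, 0].
def _find_bbox(node):
    """Depth-first search for the first list of exactly four ints."""
    if isinstance(node, list):
        if len(node) == 4 and all(isinstance(v, int) and not isinstance(v, bool) for v in node):
            return list(node)
        children = node
    elif isinstance(node, dict):
        children = list(node.values())
    else:
        return None
    for child in children:
        found = _find_bbox(child)
        if found is not None:
            return found
    return None


def extract_bbox_answer_json(content):
    match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    if match:
        try:
            parsed = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            parsed = None
        bbox = _find_bbox(parsed)
        if bbox is not None:
            return bbox
    return [0, 0, 0, 0]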

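def iou(box1, box2):
    # IoU of two [x1, y1, x2, y2] boxes. The -1/+1 offsets cancel, so the
    # intersection is (min x2 - max x1) by (min y2 - max y1), the same
    # convention as the areas in `union`; an overlap exactly one unit wide
    # or tall is counted as zero.
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2]-1, box2[2]-1)
    inter_y2 = min(box1[3]-1, box2[3]-1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
    else:
        inter = 0
    union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
    # Guard against a zero-area (or malformed) union, e.g. two degenerate boxes.
    if union <= 0:
        return 0.0
    return float(inter)/union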
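
# Illustrative sketch (hypothetical helper, not called anywhere in this
# script): the same IoU for [x1, y1, x2, y2] boxes written with clamped overlap
# widths. For well-formed boxes it agrees with `iou` above except when the
# overlap is exactly one unit wide or tall, which `iou` counts as zero.
# A union with no positive area yields 0.0.
def iou_clamped(box1, box2):
    inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = inter_w * inter_h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0

# Example: iou_clamped([0, 0, 10, 10], [5, 5, 15, 15]) == 25 / 175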

num_samples = 2000
for ds in TEST_DATASETS:
    if rank == 0:
        print(f"Processing {ds}...")
    ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
    with open(ds_path, "r") as f:
        data = json.load(f)
    random.seed(42)
    random.shuffle(data)
    data = data[:num_samples]

    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

    # Split data for distributed evaluation
    per_rank_data = len(data) // world_size
    start_idx = rank * per_rank_data
    end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
    rank_data = data[start_idx:end_idx]

    messages = []

    for x in rank_data:
        image_path = os.path.join(IMAGE_ROOT, x['image'])
        message = [
            # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
            "role": "user",
            "content": [
                {
                    "type": "image", 
                    "image": f"file://{image_path}"
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=x['problem'])
                }
            ]
        }]
        messages.append(message)

    rank_outputs = [] # List to store answers for this rank
    all_outputs = []  # List to store all answers

    # Process data
    for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank):
        batch_messages = messages[i:i + BSZ]
    
        # Preparation for inference
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Inference: Generation of the output
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
        
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        
        rank_outputs.extend(batch_output_text)

    print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")

    # Gather all outputs from all ranks
    all_outputs = [None] * len(data)
    rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]

    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, rank_results)
    
    assert gathered_results[-1][-1][0] == len(data) - 1

    # The main process will collect all results
    if rank == main_rank:
        for results in gathered_results:
            for idx, output in results:
                assert idx < len(all_outputs)
                all_outputs[idx] = output
        assert all_outputs[-1] is not None

        final_output = []
        correct_number = 0

        for input_example, model_output in zip(data, all_outputs):
            original_output = model_output
            ground_truth = input_example['solution']
            model_answer = extract_bbox_answer(original_output)
            
            # Count correct answers
            correct = 0
            if model_answer is not None:
                if iou(model_answer, ground_truth) > 0.5:
                    correct = 1
            correct_number += correct
            
            # Create a result dictionary for this example
            result = {
                'image': input_example['image'],
                'question': input_example['problem'],
                'ground_truth': ground_truth,
                'model_output': original_output,
                'extracted_answer': model_answer,
                'correct': correct
            }
            final_output.append(result)

        # Calculate and print accuracy
        accuracy = correct_number / len(data) * 100
        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")

        # Save results to a JSON file
        output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(output_path, "w") as f:
            json.dump({
                'accuracy': accuracy,
                'results': final_output
            }, f, indent=2)

        print(f"Results saved to {output_path}")
        print("-"*100)

    # Synchronize all processes
    dist.barrier()







================================================
FILE: src/eval/test_rec_r1_internvl.py
================================================
import torch
import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
from open_r1.vlm_modules.internvl_module import InvernVLModule

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

def setup_distributed():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank) 
    
    dist.init_process_group(backend="nccl")
    
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    
    return local_rank, world_size, rank

local_rank, world_size, rank = setup_distributed()
device = f"cuda:{local_rank}"
print(f"Process {rank} using {device}")

main_rank = 0
steps = 300
if rank == main_rank:
    print("Steps: ", steps)

RUN_NAME = "InternVL2_5-4B_MPO-rec"

MODEL_PATH=f"/training/shz/project/vlm-r1/VLM-R1/checkpoints/rl/{RUN_NAME}/checkpoint-{steps}" 
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"

BSZ=4
DATA_ROOT = "/training/shz/dataset/vlm-r1/rec_jsons_internvl"

# TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
# IMAGE_ROOT = "/training/shz/dataset/coco"

TEST_DATASETS = ['lisa_test']
IMAGE_ROOT = "/training/shz/dataset/lisa"

random.seed(42)

vlm_module = InvernVLModule()

model = vlm_module.get_model_class(MODEL_PATH, {}).from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map={"": local_rank},
    trust_remote_code=True,
    use_flash_attn=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
vlm_module.post_model_init(model, tokenizer)


def extract_bbox_answer(content):
    # Find the bbox inside <answer> tags; return [0, 0, 0, 0] if none is found.
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
    content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if content_answer_match:
        content_answer = content_answer_match.group(1).strip()
        bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
        if bbox_match:
            bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
            return bbox
    return [0, 0, 0, 0]

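def iou(box1, box2):
    # IoU of two [x1, y1, x2, y2] boxes. The -1/+1 offsets cancel, so the
    # intersection is (min x2 - max x1) by (min y2 - max y1), the same
    # convention as the areas in `union`; an overlap exactly one unit wide
    # or tall is counted as zero.
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2]-1, box2[2]-1)
    inter_y2 = min(box1[3]-1, box2[3]-1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
    else:
        inter = 0
    union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
    # Guard against a zero-area (or malformed) union, e.g. two degenerate boxes.
    if union <= 0:
        return 0.0
    return float(inter)/union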

from PIL import Image
def process_vision_info(batch_messages):
    images = []
    for msg in batch_messages:
        image_path = msg[0]['content'][0]['image'].replace("file://", "")
        image = Image.open(image_path)
        images.append(image)
    return images


sample_num = 2000
tokenizer.max_anyres_num = 12
for ds in TEST_DATASETS:
    if rank == main_rank:
        print(f"Processing {ds}...")
    ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
    with open(ds_path, "r") as f:
        data = json.load(f)
    random.seed(42)
    random.shuffle(data)
    data = data[:sample_num]
    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."

    # Split data for distributed evaluation
    per_rank_data = len(data) // world_size
    start_idx = rank * per_rank_data
    end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
    rank_data = data[start_idx:end_idx]

    messages = []
    for x in rank_data:
        image_path = os.path.join(IMAGE_ROOT, x['image'])
        message = [
            # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
            "role": "user",
            "content": [
                {
                    "type": "image", 
                    "image": f"file://{image_path}"
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=x['problem'])
                }
            ]
        }]
        messages.append(message)
    
    rank_outputs = [] # List to store answers for this rank
    all_outputs = []  # List to store all answers

    # Process data
    for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank):
        batch_messages = messages[i:i + BSZ]
        prompts = vlm_module.prepare_prompt(None, [{"prompt": msg} for msg in batch_messages])

        images = process_vision_info(batch_messages)

        model_inputs = vlm_module.prepare_model_inputs(tokenizer, prompts, images)
        model_inputs['pixel_values'] = model_inputs['pixel_values'].to(torch.bfloat16)
        model_inputs = model_inputs.to(device)

        outputs = model.generate(**{k:v for k,v in model_inputs.items() if k not in vlm_module.get_non_generate_params()}, max_new_tokens=1024, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        batch_output_text = tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        rank_outputs.extend(batch_output_text)
    
    print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")

    # Gather all outputs from all ranks
    all_outputs = [None] * len(data)
    rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]

    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, rank_results)
    
    assert gathered_results[-1][-1][0] == len(data) - 1

    # The main process will collect all results
    if rank == main_rank:
        for results in gathered_results:
            for idx, output in results:
                assert idx < len(all_outputs)
                all_outputs[idx] = output
        assert all_outputs[-1] is not None

        final_output = []
        correct_number = 0

        for input_example, model_output in zip(data, all_outputs):
            original_output = model_output
            ground_truth = input_example['solution']
            model_answer = extract_bbox_answer(original_output)
            
            # Count correct answers
            correct = 0
            if model_answer is not None and iou(model_answer, ground_truth) > 0.5:
                correct = 1
            correct_number += correct
            
            # Create a result dictionary for this example
            result = {
                'image': input_example['image'],
                'question': input_example['problem'],
                'ground_truth': ground_truth,
                'model_output': original_output,
                'extracted_answer': model_answer,
                'correct': correct
            }
            final_output.append(result)

        # Calculate and print accuracy
        accuracy = correct_number / len(data) * 100
        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")

        # Save results to a JSON file
        output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(output_path, "w") as f:
            json.dump({
                'accuracy': accuracy,
                'results': final_output
            }, f, indent=4)

        print(f"Results saved to {output_path}")
        print("-"*100)

    # Synchronize all processes
    dist.barrier()







================================================
FILE: src/open-r1-multimodal/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc

# Temp folders
data/
wandb/
scripts/
checkpoints/
.vscode/

================================================
FILE: src/open-r1-multimodal/LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: src/open-r1-multimodal/Makefile
================================================
.PHONY: style quality

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := src

style:
	black --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort $(check_dirs) setup.py

quality:
	black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort --check-only $(check_dirs) setup.py
	flake8 --max-line-length 119 $(check_dirs) setup.py


# Evaluation

evaluate:


================================================
FILE: src/open-r1-multimodal/configs/ddp.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


================================================
FILE: src/open-r1-multimodal/configs/zero2.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

================================================
FILE: src/open-r1-multimodal/configs/zero3.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


================================================
FILE: src/open-r1-multimodal/local_scripts/zero2.json
================================================
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": false,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
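The `"auto"` entries in these DeepSpeed JSON files are placeholders that the Hugging Face Trainer's DeepSpeed integration fills from the training arguments at launch time. DeepSpeed requires the three batch-size fields to agree: `train_batch_size` must equal `train_micro_batch_size_per_gpu * gradient_accumulation_steps * number_of_processes`. The sketch below shows only that arithmetic. `deepspeed_train_batch_size` and `fill_auto_batch_fields` are hypothetical helpers written for illustration, not DeepSpeed or Transformers APIs.

```python
import copy


def deepspeed_train_batch_size(micro_batch_per_gpu: int, grad_accum_steps: int, world_size: int) -> int:
    """Global batch size DeepSpeed expects: micro batch * accumulation steps * number of processes."""
    for name, value in (
        ("micro_batch_per_gpu", micro_batch_per_gpu),
        ("grad_accum_steps", grad_accum_steps),
        ("world_size", world_size),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    return micro_batch_per_gpu * grad_accum_steps * world_size


def fill_auto_batch_fields(ds_config: dict, micro_batch_per_gpu: int, grad_accum_steps: int, world_size: int) -> dict:
    """Hypothetical helper: replace "auto" batch-size entries with mutually consistent values."""
    resolved = {
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "train_batch_size": deepspeed_train_batch_size(micro_batch_per_gpu, grad_accum_steps, world_size),
    }
    config = copy.deepcopy(ds_config)  # leave the caller's dict untouched
    for key, value in resolved.items():
        # Only placeholders are filled; missing keys and concrete values are left alone
        if config.get(key) == "auto":
            config[key] = value
    return config


if __name__ == "__main__":
    # Subset of local_scripts/zero2.json, launched with 8 processes as in configs/zero2.yaml
    zero2 = {
        "gradient_accumulation_steps": "auto",
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "zero_optimization": {"stage": 2},
    }
    filled = fill_auto_batch_fields(zero2, micro_batch_per_gpu=2, grad_accum_steps=4, world_size=8)
    print(filled["train_batch_size"])  # 2 * 4 * 8 = 64
```

Leaving all three fields as `"auto"` is what keeps them consistent: if one is hard-coded while the others are filled in, the product can disagree and DeepSpeed will reject the configuration.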

================================================
FILE: src/open-r1-multimodal/local_scripts/zero3.json
================================================
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

================================================
FILE: src/open-r1-multimodal/local_scripts/zero3.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


================================================
FILE: src/open-r1-multimodal/local_scripts/zero3_offload.json
================================================
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 1e5,
    "wall_clock_breakdown": false
}

================================================
FILE: src/open-r1-multimodal/local_scripts/zero_stage2_config.json
================================================
{
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 1e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e8,
    "contiguous_gradients": true
  },
  "fp16": {
    "enabled": "auto",
    "auto_cast": true,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}


================================================
FILE: src/open-r1-multimodal/setup.cfg
================================================
[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = open_r1
known_third_party =
    transformers
    datasets
    fugashi
    git
    h5py
    matplotlib
    nltk
    numpy
    packaging
    pandas
    psutil
    pytest
    rouge_score
    sacrebleu
    seqeval
    sklearn
    streamlit
    torch
    tqdm

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E501, E741, W503, W605
max-line-length = 119
per-file-ignores =
    # imported but unused
    __init__.py: F401

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS

================================================
FILE: src/open-r1-multimodal/setup.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py


import re
import shutil
from pathlib import Path

from setuptools import find_packages, setup


# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
if stale_egg_info.exists():
    print(
        (
            "Warning: {} exists.\n\n"
            "If you recently updated open_r1, this is expected,\n"
            "but it may prevent open_r1 from installing in editable mode.\n\n"
            "This directory is automatically generated by Python's packaging tools.\n"
            "I will remove it now.\n\n"
            "See https://github.com/pypa/pip/issues/5466 for details.\n"
        ).format(stale_egg_info)
    )
    shutil.rmtree(stale_egg_info)


# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
#   * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
    "accelerate>=1.2.1",
    "bitsandbytes>=0.43.0",
    "black>=24.4.2",
    "datasets>=3.2.0",
    "deepspeed==0.15.4",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "einops>=0.8.0",
    "flake8>=6.0.0",
    "hf_transfer>=0.1.4",
    "huggingface-hub[cli]>=0.19.2,<1.0",
    "isort>=5.12.0",
    "liger_kernel==0.5.2",
    # "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
    "math-verify",  # Used for math verification in grpo
    "packaging>=23.0",
    "parameterized>=0.9.0",
    "pytest",
    "safetensors>=0.3.3",
    "sentencepiece>=0.1.99",
    "torch>=2.5.1",
    "transformers==4.49.0",
    "trl @ git+https://github.com/huggingface/trl.git@main",
    "vllm==0.6.6.post1",
    "wandb>=0.19.1",
    "pillow",
]

# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["tests"] = deps_list("pytest", "parameterized")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("black", "isort", "flake8")
# extras["eval"] = deps_list("lighteval", "math-verify")
extras["eval"] = deps_list("math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]

# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
    deps["accelerate"],
    deps["bitsandbytes"],
    deps["einops"],
    deps["datasets"],
    deps["deepspeed"],
    deps["hf_transfer"],
    deps["huggingface-hub"],
    deps["liger_kernel"],
    deps["packaging"],  # utilities from PyPA to e.g., compare versions
    deps["safetensors"],
    deps["sentencepiece"],
    deps["transformers"],
    deps["trl"],
]

setup(
    name="open-r1",
    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future)",
    author_email="lewis@huggingface.co",
    description="Open R1",
    # long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    keywords="llm inference-time compute reasoning",
    license="Apache",
    url="https://github.com/huggingface/open-r1",
    package_dir={"": "src"},
    packages=find_packages("src"),
    zip_safe=False,
    extras_require=extras,
    python_requires=">=3.10.9",
    install_requires=install_requires,
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
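The `deps` lookup table keys each requirement string by its bare package name, stripping any `[extras]`, version specifier, or ` @ url` direct reference. The sketch below reruns the same regular expression from `setup.py` on a few entries copied from `_deps`, to show which key each requirement ends up under. It only reuses code from this file.

```python
import re

# A few entries copied from _deps in setup.py
_deps = [
    "deepspeed==0.15.4",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "huggingface-hub[cli]>=0.19.2,<1.0",
    "math-verify",
    "trl @ git+https://github.com/huggingface/trl.git@main",
]

# Group 1 is the full requirement string; group 2 is the bare name that precedes
# any "[extras]", any version operator (!=<>~), or a space (as in "pkg @ url").
_PATTERN = r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)"

# Same construction as setup.py: map bare name -> full requirement string
deps = {name: full for full, name in (re.findall(_PATTERN, x)[0] for x in _deps)}


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


if __name__ == "__main__":
    for name, requirement in deps.items():
        print(f"{name!r:20} -> {requirement!r}")
```

This is why `install_requires` can refer to `deps["huggingface-hub"]` and get the full `huggingface-hub[cli]>=0.19.2,<1.0` pin back.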


================================================
FILE: src/open-r1-multimodal/src/open_r1/__init__.py
================================================


================================================
FILE: src/open-r1-multimodal/src/open_r1/configs.py
================================================
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

import trl


# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
    """
    Extends trl.GRPOConfig with arguments for callbacks, benchmarks, Hub revisions, and Weights & Biases logging.
    """

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )


@dataclass
class SFTConfig(trl.SFTConfig):
    """
    Extends trl.SFTConfig with arguments for callbacks, benchmarks, Hub revisions, and Weights & Biases logging.
    """

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use for benchmarking."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
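Both config classes declare their list-valued fields with `default_factory=lambda: []` (equivalent to `default_factory=list`) instead of `= []`. `dataclasses` rejects a bare mutable default when the class is created, and a factory gives every config instance its own list. A minimal sketch with a hypothetical stand-in class, `ExtraArgs`, so it runs without `trl` installed:

```python
from dataclasses import dataclass, field, fields


@dataclass
class ExtraArgs:
    """Hypothetical stand-in for the list fields added on top of trl's configs."""

    benchmarks: list[str] = field(
        default_factory=list, metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=list, metadata={"help": "The callbacks to run during training."}
    )


def bare_list_default_is_rejected() -> bool:
    """Return True if @dataclass refuses a shared mutable default such as `= []`."""
    try:

        @dataclass
        class Bad:
            benchmarks: list[str] = []

    except ValueError:
        return True
    return False


if __name__ == "__main__":
    a, b = ExtraArgs(), ExtraArgs()
    a.benchmarks.append("aime24")
    print(b.benchmarks)  # [] because each instance got its own list
    print(bare_list_default_is_rejected())  # True
    print({f.name: f.metadata["help"] for f in fields(ExtraArgs)})
```

The `metadata={"help": ...}` strings are what argument parsers built on `HfArgumentParser` (such as `TrlParser`) surface as `--help` text.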

================================================
FILE: src/open-r1-multimodal/src/open_r1/evaluate.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom evaluation tasks for LightEval."""

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language


metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,
)


def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


# Define tasks
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)

# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(math_500)

# MODULE LOGIC
if __name__ == "__main__":
    print([t.name for t in TASKS_TABLE])
    print(len(TASKS_TABLE))


================================================
FILE: src/open-r1-multimodal/src/open_r1/generate.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration


def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
) -> Pipeline:
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                api_key="something",
                model=model,
                # thinking can take some time...
                timeout=10 * 60,
                generation_kwargs=generation_kwargs,
            ),
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
            num_generations=num_generations,
        )

    return pipeline


if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )

    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
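`build_distilabel_pipeline` only forwards `temperature` and `top_p` when they are explicitly set, so unset sampling parameters are left to the library and server defaults rather than overridden. The same logic as a pure function makes that behaviour easy to check without starting a Ray pipeline. `build_generation_kwargs` is a hypothetical name for this extracted helper, not a function in the repository.

```python
from typing import Optional


def build_generation_kwargs(
    max_new_tokens: int = 8192,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> dict:
    """Mirror of the kwargs logic in build_distilabel_pipeline: only explicit values are forwarded."""
    generation_kwargs = {"max_new_tokens": max_new_tokens}
    # Compare against None, not truthiness, so temperature=0.0 (greedy) is still sent
    if temperature is not None:
        generation_kwargs["temperature"] = temperature
    if top_p is not None:
        generation_kwargs["top_p"] = top_p
    return generation_kwargs


if __name__ == "__main__":
    print(build_generation_kwargs())  # {'max_new_tokens': 8192}
    print(build_generation_kwargs(temperature=0.6, top_p=0.95))
```

Testing `is not None` rather than truthiness is the important design choice: `--temperature 0` still reaches the server instead of being silently dropped.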


================================================
FILE: src/open-r1-multimodal/src/open_r1/grpo_jsonl.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import pathlib
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from babel.numbers import parse_decimal
from utils.math import compute_score
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from math_verify import parse, verify
from trainer import VLMGRPOTrainer, GRPOConfig
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
import PIL
from Levenshtein import ratio
from utils.pycocotools.coco import COCO
from utils.pycocotools.cocoeval import COCOeval
import json
import math
from json_repair import repair_json

from vlm_modules import *
from typing import Tuple
from transformers.utils import logging
from transformers import AutoProcessor, AutoTokenizer

from openai import OpenAI


logger = logging.get_logger(__name__)

client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", ""),  # Must be set via environment variable
    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
)

from qwen2_5vl_monkey_patch import monkey_patch_qwen2_5vl_flash_attn, monkey_patch_qwen2_5vl_forward, monkey_patch_torch_load

monkey_patch_qwen2_5vl_flash_attn()    
monkey_patch_torch_load()

tokenizer = None

def initialize_tokenizer(model_path):
    global tokenizer
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        print(f"Is Fast Tokenizer? {tokenizer.is_fast}")
    return tokenizer

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.
    """
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    image_folders: str = field(
        default=None,
        metadata={"help": "Paths to image folders, separated by ':'"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    reward_method: Optional[str] = field(
        default=None,
        metadata={
            "help": "Choose reward method: 'default', 'mcp', ..."
        },
    )
    task_type: Optional[str] = field(
        default=None,
        metadata={"help": "Choose task type: 'default', 'gui', ..."},
    )
    is_reward_customized_from_vlm_module: bool = field(
        default=False,
        metadata={"help": "Whether to use a customized reward from vlm module"},
    )

def extract_choice(text):
    # 1. Clean and normalize text
    text = text.upper()  # Convert to uppercase
    text = re.sub(r'\s+', ' ', text)  # Normalize spaces

    # 2. A choice is a single letter that is not preceded by another ASCII letter and is
    #    followed by punctuation or the end of the text
    matches = list(re.finditer(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text))

    if not matches:
        return None

    # 3. If only one distinct choice, return it directly
    choices = [m.group(1) for m in matches]
    if len(set(choices)) == 1:
        return choices[0]

    # 4. If multiple choices, use heuristic rules
    choice_scores = {choice: 0 for choice in choices}

    # 4.1 Keywords around choices get points
    #     (Chinese keywords: answer, choose, correct, is, right, think, should, feel)
    keywords = [
        '答案', '选择', '正确', '是', '对',
        'answer', 'correct', 'choose', 'select', 'right',
        '认为', '应该', '觉得', 'think', 'believe', 'should'
    ]

    # Score each occurrence at its own match position, using 20 chars of context on each side
    for m in matches:
        choice, pos = m.group(1), m.start(1)
        context = text[max(0, pos-20):min(len(text), pos+20)]

        # Add points for keywords
        for keyword in keywords:
            if keyword.upper() in context:
                choice_scores[choice] += 1

        # Add points if choice is near the end (usually the final answer)
        if pos > len(text) * 0.7:  # In last 30% of text
            choice_scores[choice] += 2

        # Add points if followed by punctuation
        if pos < len(text) - 1 and text[pos+1] in '。.!！,，':
            choice_scores[choice] += 1

    # Return the highest-scoring choice (ties go to the earliest occurrence)
    return max(choice_scores.items(), key=lambda x: x[1])[0]

def evaluate_answer_similarity(student_answer, ground_truth):
    """Use llm to evaluate answer similarity."""
    try:
        response = client.chat.completions.create(
            model="qwen2.5:7b",
            messages=[
                {
                    "role": "user",
                    "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
                },
                {
                    "role": "user",
                    "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
                }
            ],
            temperature=0
        )
        result = response.choices[0].message.content.strip()
        return float(result)

    except Exception as e:
        print(f"Error in LLM evaluation: {e}")
        # If the API call fails or the reply is not a number, fall back to exact string matching
        return 1.0 if student_answer == ground_truth else 0.0

def llm_reward(content, sol, **kwargs):
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
    
    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()
    return evaluate_answer_similarity(student_answer, ground_truth)

def mcq_reward(content, sol, **kwargs):
    # For multiple choice, extract and compare choices
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
    has_choices = extract_choice(ground_truth)
    correct_choice = has_choices.upper() if has_choices else ground_truth

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()
    student_choice = extract_choice(student_answer)
    if student_choice:
        reward = 1.0 if student_choice == correct_choice else 0.0
    else:
        reward = 0.0

    return reward


def yes_no_reward(content, sol, **kwargs):
    content = content.lower()
    sol = sol.lower()

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()

    # Match whole words only, so "not" or "know" is not read as "no"
    ground_yes_no = re.search(r'\b(yes|no)\b', ground_truth)
    ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
    student_yes_no = re.search(r'\b(yes|no)\b', student_answer)
    student_yes_no = student_yes_no.group(1) if student_yes_no else ''

    reward = 1.0 if ground_yes_no == student_yes_no else 0.0

    return reward

# score_type indexes COCOeval.stats: 0 for mAP@[0.5:0.95], 1 for mAP@0.5
def calculate_map(pred_bbox_list, gt_bbox_list, score_type=0):
    # Calculate mAP

    # Initialize COCO object for ground truth
    gt_json = {"annotations": [], "images": [], "categories": []}
    gt_json["images"] = [{
        "id": 0,
        "width": 2048,
        "height": 2048,
        "file_name": "image_0.jpg"
    }]

    gt_json["categories"] = []

    cats2id = {}
    cat_count = 0
    for idx, gt_bbox in enumerate(gt_bbox_list):
        if gt_bbox["label"] not in cats2id:
            cats2id[gt_bbox["label"]] = cat_count
            gt_json["categories"].append({
                "id": cat_count,
                "name": gt_bbox["label"]
            })
            cat_count += 1
        
        gt_json["annotations"].append({
            "id": idx+1,
            "image_id": 0,
            "category_id": cats2id[gt_bbox["label"]],
            "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
            "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
            "iscrowd": 0
        })
    coco_gt = COCO(gt_json)

    dt_json = []
    for idx, pred_bbox in enumerate(pred_bbox_list):
        try:
            dt_json.append({
                "image_id": 0,
                "category_id": cats2id[pred_bbox["label"]],
                "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                "score": 1.0,
                "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
            })
        except Exception:
            # Skip predictions with unknown labels or malformed boxes
            pass
    
    if len(dt_json) == 0:
        return 0.0
    
    coco_dt = coco_gt.loadRes(dt_json)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")

    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return coco_eval.stats[score_type]
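
# Sketch (hypothetical helper, not called by calculate_map): the boxes in
# "bbox_2d" are corner coordinates [x1, y1, x2, y2], while COCO annotations and
# detections use [x, y, width, height]. calculate_map performs this conversion
# inline and stores width * height as the area.
def _xyxy_to_coco_sketch(bbox_2d):
    """Hypothetical: convert [x1, y1, x2, y2] to (COCO [x, y, w, h], area)."""
    x1, y1, x2, y2 = bbox_2d
    width, height = x2 - x1, y2 - y1
    return [x1, y1, width, height], width * height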

def map_reward(content, sol, length_reward=False, score_type=0, **kwargs):
    """
    Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes.
    
    Args:
        content (str): String containing predicted bounding boxes in JSON format
        sol (str): String containing ground truth bounding boxes in JSON format
        length_reward (bool, optional): Whether to include length penalty in reward calculation. Defaults to False.
        score_type (int, optional): Type of COCO evaluation metric to use. Defaults to 0 (mAP).
        **kwargs: Additional keyword arguments
        
    Returns:
        float: mAP reward score between 0 and 1. If length_reward is True, the score is multiplied by a length penalty factor.
    """
    # Extract JSON content between ```json tags
    pattern = r'```json(.*?)```'
    json_match = re.findall(pattern, sol, re.DOTALL)
    bbox_json = json_match[-1].strip() if json_match else None

    # Parse ground truth JSON to get bbox list
    gt_bbox_list = []
    if bbox_json:
        bbox_data = json.loads(bbox_json)
        gt_bbox_list = [item for item in bbox_data]
    
    # Parse predicted JSON to get bbox list
    pred_bbox_list = []
    json_match = re.findall(pattern, content, re.DOTALL)
    if json_match:
        try:
            bbox_data = json.loads(json_match[-1].strip())
            pred_bbox_list = [item for item in bbox_data]
        except Exception:
            # Fall back to an empty list if JSON parsing fails
            pred_bbox_list = []

    # Calculate mAP if both prediction and ground truth exist
    if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
        bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list, score_type=score_type)
    elif len(pred_bbox_list) == 0 and len(gt_bbox_list) == 0:
        bbox_reward = 1.0
    else:
        bbox_reward = 0.0
    
    if length_reward:
        # Calculate length penalty based on ratio of ground truth to predicted bounding boxes
        gt_length = len(gt_bbox_list)
        pred_length = len(pred_bbox_list)
        # Full score if the prediction has no more boxes than the ground truth; otherwise scale by gt_length / pred_length
        length_score = 1.0 if gt_length >= pred_length else gt_length/pred_length
        return bbox_reward * length_score
    else:
        return bbox_reward
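
# Worked example of the length penalty in map_reward (hypothetical helper that
# mirrors the expression used when length_reward=True): with 3 ground truth
# boxes, predicting 6 boxes halves the mAP reward, while predicting 3 or fewer
# leaves it unchanged.
def _length_penalty_sketch(gt_length, pred_length):
    """Hypothetical: multiplicative penalty applied to the mAP reward."""
    return 1.0 if gt_length >= pred_length else gt_length / pred_length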

def od_reward(content, sol, score_type=0, **kwargs):
    """
    Calculate reward for object detection task by comparing predicted and ground truth answers.
    
    Args:
        content (str): Model's predicted answer containing bounding box annotations
        sol (str): Ground truth answer containing bounding box annotations 
        score_type (int): Type of COCO evaluation metric to use (default: 0 for mAP)
        **kwargs: Additional keyword arguments
        
    Returns:
        float: Reward score between 0 and 1 based on mAP between predicted and ground truth boxes
    """
    # Pattern to extract content between <answer> tags
    match_pattern = r'<answer>(.*?)</answer>'

    # Extract ground truth answer
    sol_match = re.search(match_pattern, sol, re.DOTALL)
    ground_truth = sol_match.group(1).strip() if sol_match else None

    # Extract predicted answer (using last match if multiple)
    content_match = re.findall(match_pattern, content, re.DOTALL)
    student_answer = content_match[-1].strip() if content_match else None

    # Return 0 if no prediction
    if student_answer is None:
        return 0.0
    # Return 1 if both prediction and ground truth are the literal string "None"
    elif ground_truth == "None" and student_answer == "None":
        return 1.0
    # Otherwise calculate mAP between prediction and ground truth
    else:
        return map_reward(student_answer, ground_truth, score_type=score_type)
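
# Sketch (hypothetical helper): od_reward and odLength_reward read the model's
# answer from the *last* <answer>...</answer> block, so a draft answer followed
# by a final one is scored on the final one, and a completion with no answer
# block yields None (which those callers score as 0.0).
import re

def _last_answer_sketch(text):
    """Hypothetical: return the stripped text of the last <answer> block, or None."""
    matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    return matches[-1].strip() if matches else None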

def odLength_reward(content, sol, **kwargs):
    """
    Calculate reward for object detection task with length penalty.
    
    Args:
        content (str): Model's predicted answer containing bounding box annotations
        sol (str): Ground truth answer containing bounding box annotations
        **kwargs: Additional keyword arguments
        
    Returns:
        float: Reward score between 0 and 1 based on mAP and length penalty
    """
    # Pattern to extract content between <answer> tags
    match_pattern = r'<answer>(.*?)</answer>'

    # Extract ground truth answer
    sol_match = re.search(match_pattern, sol, re.DOTALL)
    ground_truth = sol_match.group(1).strip() if sol_match else None
    # Extract predicted answer (using last match if multiple)
    content_match = re.findall(match_pattern, content, re.DOTALL)
    student_answer = content_match[-1].strip() if content_match else None

    # Return 0 if no prediction
    if student_answer is None:
        return 0.0
    # Return 1 if both prediction and ground truth are the literal string "None"
    elif ground_truth == "None" and student_answer == "None":
        return 1.0
    # Calculate mAP with length penalty
    else:
        bbox_reward = map_reward(student_answer, ground_truth, length_reward=True, score_type=0)
        return bbox_reward

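def iou(box1, box2):
    """Return the IoU of two [x1, y1, x2, y2] boxes.

    The intersection uses an inclusive-pixel convention (x2 - 1, then + 1), so
    overlaps of one unit or less in either dimension count as zero. A zero-area
    union raises ZeroDivisionError; detection_score catches it and uses 0.0.
    """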
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2]-1, box2[2]-1)
    inter_y2 = min(box1[3]-1, box2[3]-1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
    else:
        inter = 0
    union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
    return float(inter)/union
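
# Sketch for comparison (hypothetical helper, not used by the reward code): the
# usual continuous-coordinate IoU for [x1, y1, x2, y2] boxes with a guard for a
# zero-area union. iou() above gives the same value when boxes overlap by more
# than one unit in both dimensions, but treats overlaps of one unit or less as
# zero and raises ZeroDivisionError when the union is zero.
def _iou_sketch(box1, box2):
    """Hypothetical: IoU with clamped intersection and a zero-union guard."""
    inter_w = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0.0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = inter_w * inter_h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter
    return inter / union if union > 0 else 0.0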

def detection_score(content, sol, iou_threshold=0.5, alpha=0.7, beta=0.0, gamma=0.3):
    """
    Calculate the comprehensive score for object detection.

    Predicted and ground truth boxes are parsed from the json code fences in
    `content` and `sol`; each box has the format
    {"bbox_2d": [x1, y1, x2, y2], "label": "category name"}.

    Parameters:
        content: Model output containing a json list of predicted boxes
        sol: Ground truth containing a json list of ground truth boxes
        iou_threshold: IoU threshold, default is 0.5
        alpha: Position accuracy weight, default is 0.7
        beta: Label accuracy weight, default is 0.0
        gamma: Completeness weight (penalty for missed/false detections), default is 0.3

    Returns:
        Comprehensive score, ranging from [0.0, 1.0]
    """
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, clean_text(content), re.DOTALL)
    content_bbox_json = json_match.group(1).strip() if json_match else None
    if content_bbox_json:
        try:
            bbox_data = json.loads(content_bbox_json)
            pred_boxes = [item for item in bbox_data]
        except Exception:
            pred_boxes = []
    else:
        pred_boxes = []

    json_match = re.search(pattern, clean_text(sol), re.DOTALL)
    sol_bbox_json = json_match.group(1).strip() if json_match else None
    if sol_bbox_json:
        bbox_data = json.loads(sol_bbox_json)
        gt_boxes = [item for item in bbox_data]
    else:
        gt_boxes = []

    # Handle edge cases
    if len(gt_boxes) == 0:
        return 1.0 if not pred_boxes else 0.0
    
    if len(pred_boxes) == 0:
        return 0.0
    
    # Initialize matching results
    matches = []  # Store matched pairs of predicted and ground truth boxes
    unmatched_preds = list(range(len(pred_boxes)))  # Indices of unmatched predicted boxes
    unmatched_gts = list(range(len(gt_boxes)))  # Indices of unmatched ground truth boxes
    
    # Calculate IoU matrix between all predicted and ground truth boxes
    iou_matrix = []
    for pred_idx, pred_box in enumerate(pred_boxes):
        iou_row = []
        for gt_idx, gt_box in enumerate(gt_boxes):
            try:
                curr_iou = iou(pred_box["bbox_2d"], gt_box["bbox_2d"])
            except:
                curr_iou = 0.0
            iou_row.append(curr_iou)
        iou_matrix.append(iou_row)
    
    # Greedy matching: find the best match for each predicted box
    while unmatched_preds and unmatched_gts:
        # Find the maximum IoU
        max_iou = -1
        max_pred_idx = -1
        max_gt_idx = -1
        
        for pred_idx in unmatched_preds:
            for gt_idx in unmatched_gts:
                curr_iou = iou_matrix[pred_idx][gt_idx]
                if curr_iou > max_iou:
                    max_iou = curr_iou
                    max_pred_idx = pred_idx
                    max_gt_idx = gt_idx
        
        # Stop matching if the maximum IoU is below the threshold
        if max_iou < iou_threshold:
            break
        
        # Record the match; a match with a wrong label still consumes both boxes but contributes zero IoU
        try:
            pred_label = pred_boxes[max_pred_idx]["label"].lower()
        except Exception:
            pred_label = ""
        try:
            gt_label = gt_boxes[max_gt_idx]["label"].lower()
        except Exception:
            gt_label = ""
        label_correct = (pred_label == gt_label)

        matches.append({
            "pred_idx": max_pred_idx,
            "gt_idx": max_gt_idx,
            "iou": max_iou if label_correct else 0,
            "label_correct": label_correct
        })
        
        # Remove matched boxes from the unmatched list
        unmatched_preds.remove(max_pred_idx)
        unmatched_gts.remove(max_gt_idx)
    
    # Position accuracy score: summed IoU of label-correct matches divided by the number of ground truth boxes
    position_score = sum(m["iou"] for m in matches) / len(gt_boxes) if matches else 0.0
    
    # Calculate label accuracy score
    label_score = sum(1.0 for m in matches if m["label_correct"]) / len(gt_boxes) if matches else 0.0
    
    # Calculate completeness score (considering missed and false detections)
    # Miss rate = number of unmatched ground truth boxes / total number of ground truth boxes
    # False alarm rate = number of unmatched predicted boxes / total number of predicted boxes
    miss_rate = len(unmatched_gts) / len(gt_boxes)
    false_alarm_rate = len(unmatched_preds) / len(pred_boxes) if pred_boxes else 0.0
    
    # Completeness score = 1 - (miss rate + false alarm rate) / 2
    completeness_score = 1.0 - (miss_rate + false_alarm_rate) / 2.0
    
    # Calculate the final comprehensive score
    final_score = (
        alpha * position_score + 
        beta * label_score + 
        gamma * completeness_score
    ) / (alpha + beta + gamma)

    return final_score
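
# Worked example of how detection_score combines its terms (hypothetical helper
# that repeats the final arithmetic, assuming every match has a correct label;
# the numbers are illustrative). With 4 ground truth boxes, 5 predictions and 3
# matches with IoUs 0.9, 0.8 and 0.7:
#   position_score     = (0.9 + 0.8 + 0.7) / 4 = 0.6
#   miss_rate          = 1 / 4 = 0.25
#   false_alarm_rate   = 2 / 5 = 0.4
#   completeness_score = 1 - (0.25 + 0.4) / 2 = 0.675
#   final (alpha=0.7, beta=0.0, gamma=0.3) = 0.7 * 0.6 + 0.3 * 0.675 = 0.6225
def _combine_detection_terms_sketch(match_ious, n_gt, n_pred, alpha=0.7, beta=0.0, gamma=0.3):
    """Hypothetical: recompute the final score from the IoUs of label-correct matches."""
    n_matched = len(match_ious)
    position_score = sum(match_ious) / n_gt
    label_score = n_matched / n_gt
    miss_rate = (n_gt - n_matched) / n_gt
    false_alarm_rate = (n_pred - n_matched) / n_pred
    completeness_score = 1.0 - (miss_rate + false_alarm_rate) / 2.0
    weighted = alpha * position_score + beta * label_score + gamma * completeness_score
    return weighted / (alpha + beta + gamma)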

def cosine_reward(content, tokenizer, acc_reward, **kwargs):
    # Cosine length-scaled reward, https://arxiv.org/abs/2502.03373
    # At gen_len == 0 the reward equals max_value; at gen_len == cosine_max_len it
    # equals min_value. Correct answers go from 1.0 down to 0.5 as length grows;
    # wrong answers go from -0.5 up to 0.0. Lengths beyond cosine_max_len are not
    # clipped, so the cosine wraps around for longer outputs.
    min_len_value_wrong = 0.0
    max_len_value_wrong = -0.5
    min_len_value_correct = 1.0
    max_len_value_correct = 0.5
    cosine_max_len = 1024

    gen_len = len(tokenizer.encode(content))
    is_correct = acc_reward >= 0.7
    
    if is_correct:
        # Swap min/max for correct answers
        min_value = max_len_value_correct
        max_value = min_len_value_correct
    else:
        min_value = min_len_value_wrong
        max_value = max_len_value_wrong

    reward = max_value - (max_value - min_value) * (1 - math.cos(gen_len * math.pi / cosine_max_len)) / 2

    return reward
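
# Sketch of the cosine length schedule (hypothetical helper; clipping gen_len
# at max_len is an assumption added here so very long outputs stay at the end
# value). The reward moves from value_at_len0 at length 0 to value_at_max_len
# at max_len along half a cosine period. With the constants in cosine_reward,
# correct answers use (1.0, 0.5) and wrong answers use (-0.5, 0.0).
import math

def _cosine_schedule_sketch(gen_len, value_at_len0, value_at_max_len, max_len=1024):
    """Hypothetical: cosine interpolation between two reward values by length."""
    progress = min(gen_len, max_len) / max_len
    weight = (1 + math.cos(math.pi * progress)) / 2  # 1 at length 0, 0 at max_len
    return value_at_max_len + (value_at_len0 - value_at_max_len) * weight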

def repetition_reward(content, **kwargs):
    max_penalty = -1.0

    if content == '':
        return 0.0

    # First, try to extract explicitly marked JSON sections
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, content, re.DOTALL)
    
    if json_match:
        bbox_json = json_match.group(1).strip()
    else:
        # If no explicitly marked JSON is found, try to find any possible JSON sections
        pattern = r'```(.*?)```'
        json_match = re.search(pattern, content, re.DOTALL)
        bbox_json = json_match.group(1).strip() if json_match else None
        
        # If still not found, try to find possible JSON array sections
        if not bbox_json:
            pattern = r'\[\s*{.*?"bbox_2d".*?"label".*?}\s*\]'
            json_match = re.search(pattern, content, re.DOTALL)
            bbox_json = json_match.group(0) if json_match else None
    
    # Try to parse JSON data
    if bbox_json:
        try:
            # Try direct parsing
            data = json.loads(bbox_json)
        except json.JSONDecodeError:
            try:
                # If direct parsing fails, try using json_repair to repair
                repaired_json = repair_json(bbox_json)
                data = json.loads(repaired_json)
            except:
                # If repair also fails, switch to plain text processing
                data = None
        if data and isinstance(data, list):
            # Ensure data is in list format
            try:
                # For JSON data, set ngram_size to 1
                ngram_size = 1
                # Combine 'bbox_2d' and 'label' of each object into a string
                items = []
                for item in data:
                    if 'bbox_2d' in item and 'label' in item:
                        items.append(f"{item['bbox_2d']}_{item['label']}")
                
                # Yield consecutive n-grams of the item list
                def zipngram(text: list, ngram_size: int):
                    return zip(*[text[i:] for i in range(ngram_size)])
                
                ngrams = set()
                total = 0

                for ng in zipngram(items, ngram_size):
                    ngrams.add(ng)
                    total += 1

                if total == 0:
                    return 0.0

                scaling = 1 - len(ngrams) / total
                reward = scaling * max_penalty

                return reward
            except KeyError:
                # If necessary keys are missing, switch to plain text processing
                pass
    
    # If no JSON section is found or JSON processing fails, treat as plain text
    ngram_size = 6
    
    if len(content.split()) < ngram_size:
        return 0.0
    
    # Yield consecutive lowercase word n-grams of the text
    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])
    
    ngrams = set()
    total = 0

    for ng in zipngram(content, ngram_size):
        ngrams.add(ng)
        total += 1

    scaling = 1 - len(ngrams) / total
    reward = scaling * max_penalty

    return reward
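
# Sketch of the plain-text branch of repetition_reward (hypothetical helper):
# the penalty is max_penalty times the fraction of repeated word n-grams,
# 1 - unique / total. For "a b a b a b" with ngram_size=2 there are 5 bigrams,
# 2 of them unique, so the penalty is -1.0 * (1 - 2/5) = -0.6.
def _ngram_repetition_penalty_sketch(text, ngram_size=6, max_penalty=-1.0):
    """Hypothetical: word n-gram repetition penalty between max_penalty and 0."""
    words = text.lower().split()
    if len(words) < ngram_size:
        return 0.0
    ngrams = [tuple(words[i:i + ngram_size]) for i in range(len(words) - ngram_size + 1)]
    return (1 - len(set(ngrams)) / len(ngrams)) * max_penalty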


def repetition_rewards(completions, solution, **kwargs):
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, sol in zip(contents, solution):
        reward = repetition_reward(content)
        rewards.append(reward)


        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <= 0.0:  # this condition can be changed for debug
                with open(log_path+"_repetition.txt", "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Repetition reward: {reward} -------------\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")

    return rewards


def cosine_rewards(completions, solution, **kwargs):
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, sol in zip(contents, solution):
        clean_content = clean_text(content)
        sol = clean_text(sol)
        if sol == "none":
            if clean_content == "none":
                acc_reward = 1.0
            else:
                acc_reward = 0.0
        else:
            acc_reward = detection_score(clean_content, sol)
        reward = cosine_reward(content, tokenizer, acc_reward)
        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <=1.0:  # this condition can be changed for debug
                with open(log_path+"_cosine.txt", "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Length reward: {reward} -------------\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")   

    return rewards

def numeric_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    try:
        content, sol = float(content), float(sol)
        return 1.0 if content == sol else 0.0
    except ValueError:
        # Not plain numbers; the caller falls back to another comparison
        return None

def math_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    return compute_score(content, sol)

def clean_text(text, exclue_chars=('\n', '\r')):
    # Extract content between <answer> and </answer> if present
    answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if answer_matches:
        # Use the last match
        text = answer_matches[-1]
    
    for char in exclue_chars:
        if char in ['\n', '\r']:
            # If there is a space before the newline, remove the newline
            text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
            # If there is no space before the newline, replace it with a space
            text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
        else:
            text = text.replace(char, ' ')
    
    # Strip surrounding whitespace and trailing periods, then convert to lowercase
    return text.strip().rstrip('.').lower()

def all_match_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    return 1.0 if content == sol else 0.0

def default_accuracy_reward(content, sol, **kwargs):
    reward = 0.0
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
    
    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()
    
    # Try symbolic verification first for numeric answers
    try:
        answer = parse(student_answer)
        if float(verify(answer, parse(ground_truth))) > 0:
            reward = 1.0
    except Exception:
        pass  # Continue to next verification method if this fails

    # If symbolic verification failed, try string matching or fuzzy matching
    if reward == 0.0:
        try: 
            # Check if ground truth contains numbers
            has_numbers = bool(re.search(r'\d', ground_truth))
            # Check if it's a multiple choice question
            has_choices = extract_choice(ground_truth)
            
            if has_numbers:
                # For numeric answers, use exact matching
                reward = numeric_reward(student_answer, ground_truth)
                if reward is None:
                    reward = ratio(clean_text(student_answer), clean_text(ground_truth))
            elif has_choices:
                # For multiple choice, extract and compare choices
                correct_choice = has_choices.upper()
                student_choice = extract_choice(student_answer)
                if student_choice:
                    reward = 1.0 if student_choice == correct_choice else 0.0
            else:
                # For text answers, use fuzzy matching
                reward = ratio(clean_text(student_answer), clean_text(ground_truth))
        except Exception:
            pass  # Keep reward as 0.0 if all methods fail

    return reward

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
        # if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
        if accu_reward_method == "mcq":
            reward = mcq_reward(content, sol)
        elif accu_reward_method == 'yes_no':
            reward = yes_no_reward(content, sol)
        elif accu_reward_method == 'llm':
            reward = llm_reward(content, sol)
        elif accu_reward_method == 'map':
            reward = map_reward(content, sol)
        elif accu_reward_method == 'math':
            reward = math_reward(content, sol)
        elif accu_reward_method == 'weighted_sum':
            clean_content = clean_text(content)
            sol = clean_text(sol)
            if sol == "none":
                if clean_content == "none":
                    reward = 1.0
                else:
                    reward = 0.0
            else:
                reward = detection_score(clean_content, sol)
        elif accu_reward_method == 'od_ap':
            reward = od_reward(content, sol)
        elif accu_reward_method == 'od_ap50':
            reward = od_reward(content, sol, score_type=1)
        elif accu_reward_method == 'odLength':
            reward = odLength_reward(content, sol)
        elif accu_reward_method == 'all_match':
            reward = all_match_reward(content, sol)
        else:
            reward = default_accuracy_reward(content, sol)  
        rewards.append(reward)
        
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <= 1.0:  # this condition can be changed for debug
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"accu_reward_method: {accu_reward_method}\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")     

        
    return rewards
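
# Sketch of an alternative to the if/elif chain in accuracy_reward
# (hypothetical, not the trainer's API): a dict maps each accu_reward_method
# name to a callable taking (content, sol), and unknown names fall back to a
# default. Variants with extra arguments (e.g. od_ap50) could be registered
# with functools.partial.
def _dispatch_accuracy_sketch(method, content, sol, registry, default):
    """Hypothetical: look up the reward function by name and apply it."""
    reward_fn = registry.get(method, default)
    return reward_fn(content, sol)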

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]

    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    if os.getenv("DEBUG_MODE") == "true":
        log_path = os.getenv("LOG_PATH")
        with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
            f.write(f"------------- {current_time} Format reward -------------\n")
            for content, match in zip(completion_contents, matches):
                f.write(f"Content: {content}\n")
                f.write(f"Has format: {bool(match)}\n")

    return [1.0 if match else 0.0 for match in matches]
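
# Sketch of the format check (hypothetical helper that reuses the pattern in
# format_reward): re.fullmatch requires the whole completion to be a <think>
# block followed, after optional whitespace, by an <answer> block, so any text
# before <think> or after </answer> fails the check.
import re

def _has_think_answer_format_sketch(text):
    """Hypothetical: True if text is exactly <think>...</think> <answer>...</answer>."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return re.fullmatch(pattern, text, re.DOTALL) is not None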


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "length": cosine_rewards,
    "repetition": repetition_rewards,
}

@dataclass
class GRPOModelConfig(ModelConfig):
    freeze_vision_modules: bool = False

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")

def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template(task_type=script_args.task_type)

    # Get reward functions 
    if script_args.is_reward_customized_from_vlm_module:
        reward_funcs = [vlm_module_cls.select_reward_func(func, script_args.task_type) for func in script_args.reward_funcs]
    else:
        reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the JSONL datasets
    import json
    from datasets import Dataset
    
    data_files = script_args.data_file_paths.split(":")
    image_folders = script_args.image_folders.split(":")
    
    if len(data_files) != len(image_folders):
        raise ValueError("Number of data files must match number of image folders")
    
    if script_args.reward_method is None:
        accu_reward_methods = ["default"] * len(data_files)
    else:
        accu_reward_methods = script_args.reward_method.split(":")
        assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"

    
    all_data = []
    for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
        with open(data_file, 'r') as f:
            for line in f:
                item = json.loads(line)
                if 'image' in item:
                    if isinstance(item['image'], str):
                        # Store image path instead of loading the image
                        item['image_path'] = [os.path.join(image_folder, item['image'])]
                        del item['image'] # remove the image column so that it can be loaded later
                    elif isinstance(item['image'], list):
                        # if the image is a list, then it is a list of images (for multi-image input)
                        item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
                        del item['image'] # remove the image column so that it can be loaded later
                    else:
                        raise ValueError(f"Unsupported image type: {type(item['image'])}")
                # Remove immediate image loading
                item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
                
                # Handle solution that could be a float or string
                solution_value = item['conversations'][1]['value']
                if isinstance(solution_value, str):
                    item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
                else:
                    # If it's a float or other non-string type, convert it to a string
                    item['solution'] = str(solution_value)
                
                del item['conversations']
                item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value
                all_data.append(item)

    dataset = Dataset.from_list(all_data)

    def make_conversation_from_jsonl(example):
        if 'image_path' in example and example['image_path'] is not None:
            assert all(os.path.exists(p) for p in example['image_path']), f"Image paths do not exist: {example['image_path']}"
            # Don't load image here, just store the path
            return {
                'image_path': list(example['image_path']),  # Store paths instead of loaded images
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }
        else:
            return {
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }

    # Map the conversations
    dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
    # Split dataset for validation if requested
    splits = {'train': dataset}
    if script_args.val_split_ratio > 0:
        train_val_split = dataset.train_test_split(
            test_size=script_args.val_split_ratio
        )
        splits['train'] = train_val_split['train']
        splits['validation'] = train_val_split['test']

    # Trainer class (currently fixed to VLMGRPOTrainer)
    trainer_cls = VLMGRPOTrainer
    print("using trainer:", trainer_cls.__name__)
    initialize_tokenizer(model_args.model_name_or_path)
    
    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=splits['train'],
        eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
    )

    # Train, resuming from the latest checkpoint if output_dir already contains one
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    if training_args.deepspeed and "zero3" in training_args.deepspeed:
        print("zero3 is used, qwen2_5vl forward monkey patch is applied")
        monkey_patch_qwen2_5vl_forward()
    main(script_args, training_args, model_args)


================================================
FILE: src/open-r1-multimodal/src/open_r1/qwen2_5vl_monkey_patch.py
================================================

# ----------------------- Fix the Qwen2.5-VL vision flash-attention forward in the current version of transformers (cast RoPE cos/sin to float32) -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple, Optional
def qwen2_5vl_vision_flash_attn_forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos().float()
            sin = emb.sin().float()
        else:
            cos, sin = position_embeddings
            # Fix: cast to float32, matching the rotary_pos_emb branch above
            cos = cos.to(torch.float)
            sin = sin.to(torch.float)
        q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
        q = q.squeeze(0)
        k = k.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output


def monkey_patch_qwen2_5vl_flash_attn():
    Qwen2_5_VLVisionFlashAttention2.forward = qwen2_5vl_vision_flash_attn_forward


# ----------------------- Fix the process hang under DeepSpeed ZeRO-3 when a batch mixes image-text and text-only data -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
from typing import List, Union
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
def qwen2_5vl_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

            has_images_global = False
            if pixel_values is not None:
                has_images_local = torch.tensor(1, device=input_ids.device)
            else:
                has_images_local = torch.tensor(0, device=input_ids.device)
            # Use all_reduce to ensure all GPUs know if there are images to process
            torch.distributed.all_reduce(has_images_local, op=torch.distributed.ReduceOp.MAX)
            has_images_global = has_images_local.item() > 0

            # If there are image inputs globally, ensure all GPUs call the visual model
            if has_images_global:
                if pixel_values is not None:   
                    pixel_values = pixel_values.type(self.visual.dtype)
                    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                    n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                    n_image_features = image_embeds.shape[0]
                    if n_image_tokens != n_image_features:
                        raise ValueError(
                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                        )
                    
                    mask = input_ids == self.config.image_token_id
                    mask_unsqueezed = mask.unsqueeze(-1)
                    mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                    image_mask = mask_expanded.to(inputs_embeds.device)
                    
                    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
                else:
                    with torch.no_grad():
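                        # Run the vision tower on a dummy input so this rank joins the ZeRO-3 parameter
                        # gathering performed by ranks that do have images: 4 patches (grid_thw 1x2x2), each
                        # of 3 channels * 2 frames * 14 * 14 pixels = 1176 values (Qwen2.5-VL defaults).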
                        dummy_pixel_values = torch.zeros((4, 1176), device=input_ids.device, dtype=self.visual.dtype)
                        dummy_grid_thw = torch.tensor([[1, 2, 2]], device=input_ids.device)
                        _ = self.visual(dummy_pixel_values, grid_thw=dummy_grid_thw)

            # Video inputs are processed as in the original forward, without the cross-rank synchronization above.
            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

def monkey_patch_qwen2_5vl_forward():
    Qwen2_5_VLForConditionalGeneration.forward = qwen2_5vl_forward

# ----------------------- Set weights_only=False in torch.load (PyTorch 2.6 changed its default to True) -----------------------
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils import logger, log_dist
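def weigths_only_load(self, path: str, map_location=None):
    # Load a DeepSpeed checkpoint with weights_only=False so that non-tensor objects stored in it can be
    # unpickled under PyTorch >= 2.6. This trusts the checkpoint file, so only load checkpoints you created.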
    logger.info(f"[Torch] Loading checkpoint from {path}...")
    partition = torch.load(path, map_location=map_location, weights_only=False)
    logger.info(f"[Torch] Loaded checkpoint from {path}.")
    return partition

def monkey_patch_torch_load():
    TorchCheckpointEngine.load = weigths_only_load





================================================
FILE: src/open-r1-multimodal/src/open_r1/trainer/__init__.py
================================================
from .grpo_trainer import VLMGRPOTrainer
from .grpo_config import GRPOConfig

__all__ = ["VLMGRPOTrainer", "GRPOConfig"]

================================================
FILE: src/open-r1-multimodal/src/open_r1/trainer/grpo_config.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class GRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
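
            A minimal sketch of this divisibility check, assuming the global batch counts
            `num_processes * per_device_train_batch_size` completions; the helper `valid_num_generations`
            is hypothetical and not part of this trainer:

            ```python
            def valid_num_generations(num_processes: int, per_device_batch_size: int) -> list[int]:
                # Hypothetical helper: each prompt contributes num_generations completions,
                # so num_generations must divide the global batch evenly.
                global_batch_size = num_processes * per_device_batch_size
                # Start at 2 because a group of one completion has no relative advantage.
                return [n for n in range(2, global_batch_size + 1) if global_batch_size % n == 0]


            print(valid_num_generations(num_processes=4, per_device_batch_size=2))  # [2, 4, 8]
            ```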
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation

        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        top_p (`float`, *optional*, defaults to `1.0`):
            Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
            `1.0` to consider all tokens.
        top_k (`int` or `None`, *optional*, defaults to `50`):
            Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
            disabled.
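
            A minimal pure-Python sketch of how `top_k` and `top_p` shrink a next-token distribution,
            assuming top-k is applied before top-p (the order used by Hugging Face generation, to our
            understanding); `filter_top_k_top_p` is a hypothetical helper:

            ```python
            from typing import Optional


            def filter_top_k_top_p(probs: list[float], top_k: Optional[int], top_p: float) -> list[float]:
                # Token ids sorted from most to least likely.
                order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
                # Top-k: keep the k most likely tokens (None or 0 disables this step).
                kept = order if not top_k else order[:top_k]
                mass = sum(probs[i] for i in kept)
                # Top-p: keep the smallest prefix whose renormalized mass reaches top_p.
                nucleus, cumulative = set(), 0.0
                for i in kept:
                    nucleus.add(i)
                    cumulative += probs[i] / mass
                    if cumulative >= top_p:
                        break
                # Zero out filtered tokens and renormalize the survivors.
                total = sum(probs[i] for i in nucleus)
                return [probs[i] / total if i in nucleus else 0.0 for i in range(len(probs))]


            print(filter_top_k_top_p([0.5, 0.3, 0.15, 0.05], top_k=3, top_p=0.8))
            ```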
        min_p (`float` or `None`, *optional*, defaults to `None`):
            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
            value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
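
            A minimal sketch of min-p filtering on a probability vector; `min_p_filter` is a hypothetical
            helper, and keeping tokens whose probability is at least `min_p * max(probs)` is our reading of
            the description above:

            ```python
            def min_p_filter(probs: list[float], min_p: float) -> list[float]:
                # The cutoff scales with the probability of the most likely token.
                threshold = min_p * max(probs)
                kept = [p if p >= threshold else 0.0 for p in probs]
                # Renormalize so the surviving tokens sum to 1 again.
                total = sum(kept)
                return [p / total for p in kept]


            # With min_p=0.2 the cutoff is 0.2 * 0.6 = 0.12, so only the first two tokens survive.
            print(min_p_filter([0.6, 0.25, 0.1, 0.05], min_p=0.2))
            ```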
        repetition_penalty (`float`, *optional*, defaults to `1.0`):
            Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
            Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
            tokens.
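
            A minimal sketch of a multiplicative repetition penalty on raw logits, assuming the rule used by
            Hugging Face's `RepetitionPenaltyLogitsProcessor` (divide positive logits, multiply negative ones);
            `apply_repetition_penalty` is a hypothetical helper:

            ```python
            def apply_repetition_penalty(logits: list[float], seen_token_ids: set[int], penalty: float) -> list[float]:
                penalized = list(logits)
                for token_id in seen_token_ids:
                    score = penalized[token_id]
                    # With penalty > 1 both branches lower the logit, so repeats become less likely;
                    # with penalty < 1 both branches raise it.
                    penalized[token_id] = score * penalty if score < 0 else score / penalty
                return penalized


            # Tokens 0 and 1 already appeared in the prompt or in the completion so far.
            print(apply_repetition_penalty([2.0, -1.0, 0.5], seen_token_ids={0, 1}, penalty=2.0))  # [1.0, -2.0, 0.5]
            ```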
        cache_implementation (`str` or `None`, *optional*, defaults to `None`):
            Implementation of the cache method for faster generation when use_vllm is set to False.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will be
            shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
            support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed, but may be numerically unstable for long training runs.
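
            A minimal sketch of how `beta` weights a per-token KL penalty, assuming the k3 estimator
            `exp(r - l) - (r - l) - 1` used by TRL's `GRPOTrainer`, where `l` and `r` are the policy and
            reference log-probabilities of a sampled token. Whether this repository's trainer uses the same
            estimator is an assumption here, and both helpers are hypothetical:

            ```python
            import math


            def per_token_kl(policy_logp: float, ref_logp: float) -> float:
                # k3 estimate of KL(pi_theta || pi_ref) from one token sampled from pi_theta;
                # it is always >= 0 and equals 0 when the two log-probabilities agree.
                delta = ref_logp - policy_logp
                return math.exp(delta) - delta - 1.0


            def kl_penalized_loss(policy_loss: float, policy_logp: float, ref_logp: float, beta: float) -> float:
                # With beta == 0.0 the KL term vanishes, which is why the reference model can be skipped.
                return policy_loss + beta * per_token_kl(policy_logp, ref_logp)


            print(kl_penalized_loss(0.5, policy_logp=-1.0, ref_logp=-2.0, beta=0.04))
            ```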
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        epsilon_high (`float` or `None`, *optional*, defaults to `None`):
            Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
            specified in argument `epsilon`. The [DAPO](https://huggingface.co/papers/2503.14476) paper recommends `0.28`.
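
            A minimal sketch of the PPO-style clipped objective for a single token, assuming the probability
            ratio is clipped to `[1 - epsilon, 1 + epsilon_high]` as described above; `clipped_surrogate` is a
            hypothetical helper, and a trainer would minimize the negative of this value:

            ```python
            import math
            from typing import Optional


            def clipped_surrogate(logp_new: float, logp_old: float, advantage: float,
                                  epsilon: float = 0.2, epsilon_high: Optional[float] = None) -> float:
                # Without epsilon_high the clip range is symmetric: [1 - epsilon, 1 + epsilon].
                upper = epsilon if epsilon_high is None else epsilon_high
                ratio = math.exp(logp_new - logp_old)
                clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + upper)
                # Pessimistic bound: take the smaller of the unclipped and clipped terms.
                return min(ratio * advantage, clipped_ratio * advantage)


            # A token whose probability rose by 50% under the new policy.
            print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0))                     # about 1.2
            print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0, epsilon_high=0.28))  # about 1.28
            ```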
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
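
            A minimal sketch of combining several reward functions into one scalar per completion, assuming
            a plain weighted sum as in TRL's `GRPOTrainer`; `combine_rewards` is a hypothetical helper:

            ```python
            from typing import Optional


            def combine_rewards(rewards_per_func: list[list[float]],
                                reward_weights: Optional[list[float]] = None) -> list[float]:
                # rewards_per_func[f][i] is the reward that function f assigns to completion i.
                num_funcs = len(rewards_per_func)
                weights = [1.0] * num_funcs if reward_weights is None else reward_weights
                if len(weights) != num_funcs:
                    raise ValueError("reward_weights must have one entry per reward function")
                num_completions = len(rewards_per_func[0])
                return [sum(w * rewards[i] for w, rewards in zip(weights, rewards_per_func))
                        for i in range(num_completions)]


            accuracy = [1.0, 0.0]  # e.g. an answer-accuracy reward for two completions
            fmt = [1.0, 1.0]       # e.g. a format reward for the same two completions
            print(combine_rewards([accuracy, fmt]))              # [2.0, 1.0]
            print(combine_rewards([accuracy, fmt], [1.0, 0.5]))  # [1.5, 0.5]
            ```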
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.
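
            A minimal sketch of the TR-DPO style reference update controlled by `ref_model_mixup_alpha` and
            `ref_model_sync_steps`, using plain dictionaries of floats in place of model parameters;
            `mix_reference` and `maybe_sync_reference` are hypothetical helpers:

            ```python
            def mix_reference(policy: dict[str, float], ref: dict[str, float], alpha: float) -> dict[str, float]:
                # pi_ref = alpha * pi_theta + (1 - alpha) * pi_ref_prev, parameter by parameter.
                return {name: alpha * policy[name] + (1.0 - alpha) * ref[name] for name in ref}


            def maybe_sync_reference(step: int, policy: dict[str, float], ref: dict[str, float],
                                     alpha: float = 0.6, sync_steps: int = 512) -> dict[str, float]:
                # The reference is refreshed only every sync_steps (tau) training steps.
                if step > 0 and step % sync_steps == 0:
                    return mix_reference(policy, ref, alpha)
                return ref


            policy, ref = {"w": 1.0}, {"w": 0.0}
            print(maybe_sync_reference(511, policy, ref))  # {'w': 0.0}, not a sync step
            print(maybe_sync_reference(512, policy, ref))  # {'w': 0.6}
            ```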

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `GRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this set to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            "help": "Number of generations per prompt to sample. The global batch size "
            "(num_processes * per_device_batch_size) must be divisible by this value."
        },
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation
    temperature: float = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    top_p: float = field(
        default=1.0,
        metadata={
            "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
            "Set to 1.0 to consider all tokens."
        },
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={
            "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
            "top-k-filtering is disabled."
        },
    )
    min_p: Optional[float] = field(
        default=None,
        metadata={
            "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
            "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
        },
    )
    repetition_penalty: float = field(
        default=1.0,
        metadata={
            "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
            "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
            "to repeat tokens."
        },
    )
    cache_implementation: Optional[str] = field(
        default=None,
        metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed, but may be numerically unstable for long training runs."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
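        metadata={
            "help": "Epsilon value for clipping. It sets the lower bound of the clip range and, unless "
            "`epsilon_high` is set, the upper bound as well."
        },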
    )
    epsilon_high: Optional[float] = field(
        default=None,
        metadata={
            "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
            "lower bound specified in argument `epsilon`. The DAPO paper recommends `0.28`."
        },
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )
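

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this repository calls it. It shows how
# `epsilon`, `epsilon_high` and `beta` above would enter a per-token GRPO loss,
# assuming the TRL-style formulation: a PPO-like clipped surrogate whose clip
# range is [1 - epsilon, 1 + epsilon_high] (the asymmetric "clip-higher" idea
# from DAPO) plus a KL penalty estimated as exp(d) - d - 1 with
# d = log pi_ref - log pi_theta. `_example_grpo_token_loss` is a hypothetical
# helper, not part of TRL or of this trainer.
# ---------------------------------------------------------------------------
import math
from typing import Optional


def _example_grpo_token_loss(
    logp: float,
    old_logp: float,
    ref_logp: float,
    advantage: float,
    epsilon: float = 0.2,
    epsilon_high: Optional[float] = None,
    beta: float = 0.04,
) -> float:
    """Loss for a single completion token under the assumptions above (lower is better)."""
    # Same fallback as the `epsilon_high` field: reuse the lower-bound epsilon.
    eps_high = epsilon if epsilon_high is None else epsilon_high
    # Importance-sampling ratio pi_theta / pi_theta_old, computed from log-probs.
    ratio = math.exp(logp - old_logp)
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + eps_high)
    # Pessimistic bound: take the smaller of the unclipped and clipped objectives.
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    loss = -surrogate
    if beta != 0.0:
        # Non-negative KL estimate that is exactly zero when pi_theta == pi_ref.
        d = ref_logp - logp
        loss += beta * (math.exp(d) - d - 1.0)
    return loss


# Example: a token whose probability rose by 50% with a positive advantage is
# capped at 1 + epsilon = 1.2, so
# _example_grpo_token_loss(math.log(1.5), 0.0, math.log(1.5), 1.0) is about -1.2;
# passing epsilon_high=0.28 moves the cap, giving about -1.28.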
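

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this repository calls it. It shows the
# TR-DPO reference update described by `sync_ref_model`,
# `ref_model_mixup_alpha` (alpha) and `ref_model_sync_steps` (tau): every tau
# steps, pi_ref = alpha * pi_theta + (1 - alpha) * pi_ref_prev, applied
# parameter by parameter. Parameters are modelled as plain dicts of floats;
# both helper names are hypothetical.
# ---------------------------------------------------------------------------


def _example_should_sync_ref(global_step: int, sync_ref_model: bool = False, sync_steps: int = 512) -> bool:
    """True when reference syncing is enabled and `global_step` is a positive multiple of `sync_steps`."""
    return sync_ref_model and global_step > 0 and global_step % sync_steps == 0


def _example_mix_reference(policy: dict, reference: dict, alpha: float = 0.6) -> dict:
    """New reference weights: alpha * policy + (1 - alpha) * previous reference."""
    return {name: alpha * policy[name] + (1.0 - alpha) * reference[name] for name in reference}


# Example: with alpha = 0.6, a weight that is 1.0 in the policy and 0.0 in the
# reference becomes 0.6 after one sync; with alpha = 1.0 the reference is simply
# replaced by a copy of the policy.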

================================================
FILE: src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union, Sized
from qwen_vl_utils import process_vision_info
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
# from trl import GRPOTrainer

from accelerate.utils import is_peft_model, set_seed
import PIL.Image

import copy
from torch.utils.data import Sampler
import warnings

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_wandb_available():
    import wandb

from vlm_modules.vlm_module import VLMBaseModule

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
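

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this repository calls it. It shows a
# hypothetical custom reward function that fits the callable form of
# `RewardFunc`, and the usual GRPO post-processing of reward outputs. It
# assumes, as in TRL, that per-function rewards are combined as a weighted sum
# (`reward_weights`, default 1.0 each) and that advantages are normalized
# within each group of completions generated for the same prompt. The
# `<think>`/`<answer>` tag format and all `_example_` names are assumptions.
# ---------------------------------------------------------------------------
import re
import statistics
from typing import Optional


def _example_format_reward(prompts: list, completions: list, **kwargs) -> list[float]:
    """1.0 if a completion looks like '<think>...</think><answer>...</answer>', else 0.0."""
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [1.0 if pattern.fullmatch(c.strip()) else 0.0 for c in completions]


def _example_group_advantages(
    rewards_per_func: list[list[float]],
    num_generations: int,
    reward_weights: Optional[list[float]] = None,
    eps: float = 1e-4,
) -> list[float]:
    """Weighted-sum the rewards, then compute (r - group mean) / (group std + eps)."""
    num_funcs = len(rewards_per_func)
    weights = reward_weights if reward_weights is not None else [1.0] * num_funcs
    if len(weights) != num_funcs:
        raise ValueError("reward_weights must match the number of reward functions")
    # Total reward per completion: sum over k of w_k * r_k.
    totals = [sum(w * r for w, r in zip(weights, column)) for column in zip(*rewards_per_func)]
    advantages = []
    # Completions of one prompt are assumed to be contiguous, as RepeatRandomSampler yields them.
    for start in range(0, len(totals), num_generations):
        group = totals[start : start + num_generations]
        mean = statistics.fmean(group)
        std = statistics.stdev(group) if len(group) > 1 else 0.0
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Example: two reward functions over two prompts with two generations each.
# Totals [1, 0, 2, 2] give advantages of roughly [+0.707, -0.707, 0.0, 0.0]:
# a group whose completions all score the same carries no learning signal.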


class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        mini_repeat_count (`int`):
            Number of times to repeat each index per batch.
        batch_size (`int`, *optional*, defaults to `1`):
            Number of unique indices per batch.
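        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times each batch of `batch_size` unique indices is replayed, consecutively, before moving on to
            the next batch. Indices that do not fill a complete batch are dropped.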
        seed (`int` or `None`, *optional*, defaults to `None`):
            Random seed for reproducibility.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]

        for chunk in indexes:
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

    def __len__(self) -> int:
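        # `__iter__` only yields complete batches of `batch_size` unique indices, so count just those.
        return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count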
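

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this repository calls it. It is a
# torch-free re-implementation of the index order produced by
# `RepeatRandomSampler.__iter__`. It uses Python's `random` module instead of
# `torch.randperm`, so for the same seed the permutation differs from the real
# sampler. Each complete chunk of `batch_size` shuffled indices is replayed
# `repeat_count` times, and within a replay every index is emitted
# `mini_repeat_count` times in a row (one per generation). A trailing
# incomplete chunk is dropped, so exactly
# (num_samples // batch_size) * batch_size * mini_repeat_count * repeat_count
# indices come out. `_example_repeat_order` is a hypothetical helper.
# ---------------------------------------------------------------------------
import random
from typing import Optional


def _example_repeat_order(
    num_samples: int,
    mini_repeat_count: int,
    batch_size: int = 1,
    repeat_count: int = 1,
    seed: Optional[int] = None,
) -> list[int]:
    rng = random.Random(seed)
    indexes = list(range(num_samples))
    rng.shuffle(indexes)
    # Split into chunks of `batch_size` and keep only complete chunks.
    chunks = [indexes[i : i + batch_size] for i in range(0, num_samples, batch_size)]
    chunks = [chunk for chunk in chunks if len(chunk) == batch_size]
    order = []
    for chunk in chunks:
        for _ in range(repeat_count):
            for index in chunk:
                order.extend([index] * mini_repeat_count)
    return order


# Example: 7 samples with batch_size=3 leave one sample unused, so
# _example_repeat_order(7, mini_repeat_count=2, batch_size=3, repeat_count=4)
# returns 2 * 3 * 2 * 4 = 48 indices drawn from 6 distinct samples.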


class VLMGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs="weqweasdas/RM-Gemma-2B",
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                  path to a *directory* containing model weights saved using
                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
                  using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
                  keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details, see
                  [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing different
              types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
            ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each s
SYMBOL INDEX (186 symbols across 24 files)

FILE: add_degradation/add_degradation.py
  function motion_blur (line 6) | def motion_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function lens_blur (line 21) | def lens_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function gaussian_noise (line 43) | def gaussian_noise(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function block_exchange (line 54) | def block_exchange(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function jpeg_compression (line 80) | def jpeg_compression(img: np.ndarray, intensity: float = 0.5) -> np.ndar...
  function mean_shift (line 97) | def mean_shift(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function color_diffusion (line 107) | def color_diffusion(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function sharpness_change (line 137) | def sharpness_change(img: np.ndarray, intensity: float = 0.5) -> np.ndar...
  function dark_illumination (line 152) | def dark_illumination(img: np.ndarray, intensity: float = 0.5) -> np.nda...
  function hsv_saturation (line 160) | def hsv_saturation(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function atmospheric_turbulence (line 171) | def atmospheric_turbulence(img: np.ndarray, intensity: float = 0.5) -> n...
  function dirty_lens (line 184) | def dirty_lens(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function scan_lines (line 223) | def scan_lines(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function graffiti (line 238) | def graffiti(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:
  function watermark_damage (line 279) | def watermark_damage(img: np.ndarray, intensity: float = 0.5) -> np.ndar...
  function lens_flare (line 302) | def lens_flare(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:

FILE: add_degradation/generate_degradation.py
  function apply_degradation_Benchmark (line 34) | def apply_degradation_Benchmark(image, method_name, intensity):
  function main (line 39) | def main():

FILE: app.py
  class ModelHandler (line 28) | class ModelHandler:
    method __init__ (line 29) | def __init__(self, model_path):
    method _load_model (line 35) | def _load_model(self):
    method predict (line 52) | def predict(self, message_dict, history, temperature, max_tokens):
  function create_chat_ui (line 160) | def create_chat_ui():

FILE: demo.py
  class ModelHandler (line 31) | class ModelHandler:
    method __init__ (line 32) | def __init__(self, model_path):
    method _load_model (line 38) | def _load_model(self):
    method predict (line 53) | def predict(self, question, image_path, temperature=DEFAULT_TEMPERATUR...
  function main (line 124) | def main():

FILE: src/eval/test_od_r1.py
  function extract_bbox_answer (line 13) | def extract_bbox_answer(content):
  function iou (line 28) | def iou(box1, box2):
  function load_model (line 41) | def load_model(model_path, device_map):
  function eval_od_r1 (line 56) | def eval_od_r1(

FILE: src/eval/test_rec_baseline.py
  function setup_distributed (line 20) | def setup_distributed():
  function extract_bbox_answer (line 62) | def extract_bbox_answer(content):
  function iou (line 72) | def iou(box1, box2):

FILE: src/eval/test_rec_r1.py
  function setup_distributed (line 20) | def setup_distributed():
  function extract_bbox_answer (line 67) | def extract_bbox_answer(content):
  function iou (line 80) | def iou(box1, box2):

FILE: src/eval/test_rec_r1_internvl.py
  function setup_distributed (line 18) | def setup_distributed():
  function extract_bbox_answer (line 70) | def extract_bbox_answer(content):
  function iou (line 83) | def iou(box1, box2):
  function process_vision_info (line 96) | def process_vision_info(batch_messages):

FILE: src/open-r1-multimodal/setup.py
  function deps_list (line 80) | def deps_list(*pkgs):

FILE: src/open-r1-multimodal/src/open_r1/configs.py
  class GRPOConfig (line 24) | class GRPOConfig(trl.GRPOConfig):
  class SFTConfig (line 54) | class SFTConfig(trl.SFTConfig):

FILE: src/open-r1-multimodal/src/open_r1/evaluate.py
  function prompt_fn (line 37) | def prompt_fn(line, task_name: str = None):

FILE: src/open-r1-multimodal/src/open_r1/generate.py
  function build_distilabel_pipeline (line 22) | def build_distilabel_pipeline(

FILE: src/open-r1-multimodal/src/open_r1/grpo_jsonl.py
  function initialize_tokenizer (line 60) | def initialize_tokenizer(model_path):
  class GRPOScriptArguments (line 68) | class GRPOScriptArguments(ScriptArguments):
  function extract_choice (line 119) | def extract_choice(text):
  function evaluate_answer_similarity (line 165) | def evaluate_answer_similarity(student_answer, ground_truth):
  function llm_reward (line 190) | def llm_reward(content, sol, **kwargs):
  function mcq_reward (line 200) | def mcq_reward(content, sol, **kwargs):
  function yes_no_reward (line 219) | def yes_no_reward(content, sol, **kwargs):
  function calculate_map (line 241) | def calculate_map(pred_bbox_list, gt_bbox_list, score_type=0):
  function map_reward (line 300) | def map_reward(content, sol, length_reward=False, score_type=0, **kwargs):
  function od_reward (line 354) | def od_reward(content, sol, score_type=0, **kwargs):
  function odLength_reward (line 388) | def odLength_reward(content, sol, **kwargs):
  function iou (line 421) | def iou(box1, box2):
  function detection_score (line 433) | def detection_score(content, sol, iou_threshold=0.5, alpha=0.7, beta=0.0...
  function cosine_reward (line 567) | def cosine_reward(content, tokenizer, acc_reward, **kwargs):
  function repetition_reward (line 594) | def repetition_reward(content, **kwargs):
  function repetition_rewards (line 688) | def repetition_rewards(completions, solution, **kwargs):
  function cosine_rewards (line 715) | def cosine_rewards(completions, solution, **kwargs):
  function numeric_reward (line 747) | def numeric_reward(content, sol, **kwargs):
  function math_reward (line 755) | def math_reward(content, sol, **kwargs):
  function clean_text (line 759) | def clean_text(text, exclue_chars=['\n', '\r']):
  function all_match_reward (line 778) | def all_match_reward(content, sol, **kwargs):
  function default_accuracy_reward (line 783) | def default_accuracy_reward(content, sol, **kwargs):
  function accuracy_reward (line 828) | def accuracy_reward(completions, solution, **kwargs):
  function format_reward (line 883) | def format_reward(completions, **kwargs):
  class GRPOModelConfig (line 909) | class GRPOModelConfig(ModelConfig):
  function get_vlm_module (line 920) | def get_vlm_module(model_name_or_path):
  function main (line 926) | def main(script_args, training_args, model_args):

FILE: src/open-r1-multimodal/src/open_r1/qwen2_5vl_monkey_patch.py
  function qwen2_5vl_vision_flash_attn_forward (line 6) | def qwen2_5vl_vision_flash_attn_forward(
  function monkey_patch_qwen2_5vl_flash_attn (line 43) | def monkey_patch_qwen2_5vl_flash_attn():
  function qwen2_5vl_forward (line 52) | def qwen2_5vl_forward(
  function monkey_patch_qwen2_5vl_forward (line 213) | def monkey_patch_qwen2_5vl_forward():
  function weigths_only_load (line 219) | def weigths_only_load(self, path: str, map_location=None):
  function monkey_patch_torch_load (line 225) | def monkey_patch_torch_load():

FILE: src/open-r1-multimodal/src/open_r1/trainer/grpo_config.py
  class GRPOConfig (line 22) | class GRPOConfig(TrainingArguments):

FILE: src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
  class RepeatRandomSampler (line 69) | class RepeatRandomSampler(Sampler):
    method __init__ (line 86) | def __init__(
    method __iter__ (line 104) | def __iter__(self):
    method __len__ (line 115) | def __len__(self) -> int:
  class VLMGRPOTrainer (line 119) | class VLMGRPOTrainer(Trainer):
    method __init__ (line 203) | def __init__(
    method _enable_gradient_checkpointing (line 471) | def _enable_gradient_checkpointing(self, model: PreTrainedModel, args:...
    method _set_signature_columns_if_needed (line 502) | def _set_signature_columns_if_needed(self):
    method _get_per_token_logps (line 512) | def _get_per_token_logps(self, model, input_ids, attention_mask, **cus...
    method _prepare_inputs (line 525) | def _prepare_inputs(self, inputs):
    method _get_key_from_inputs (line 529) | def _get_key_from_inputs(self, x, key):
    method _generate_and_score_completions (line 537) | def _generate_and_score_completions(self, inputs: dict[str, Union[torc...
    method compute_loss (line 720) | def compute_loss(self, model, inputs, return_outputs=False, num_items_...
    method log (line 780) | def log(self, logs: dict[str, float], start_time: Optional[float] = No...
    method create_model_card (line 789) | def create_model_card(
    method _get_train_sampler (line 847) | def _get_train_sampler(self) -> Sampler:
    method _get_eval_sampler (line 863) | def _get_eval_sampler(self, eval_dataset) -> Sampler:

FILE: src/open-r1-multimodal/src/open_r1/utils/callbacks.py
  function is_slurm_available (line 28) | def is_slurm_available() -> bool:
  class DummyConfig (line 37) | class DummyConfig:
    method __init__ (line 38) | def __init__(self, **kwargs):
  class PushToHubRevisionCallback (line 43) | class PushToHubRevisionCallback(TrainerCallback):
    method __init__ (line 44) | def __init__(self, model_config) -> None:
    method on_save (line 47) | def on_save(self, args: TrainingArguments, state: TrainerState, contro...
  function get_callbacks (line 79) | def get_callbacks(train_config, model_config) -> List[TrainerCallback]:

FILE: src/open-r1-multimodal/src/open_r1/utils/evaluation.py
  function register_lighteval_task (line 26) | def register_lighteval_task(
  function get_lighteval_tasks (line 55) | def get_lighteval_tasks():
  function run_lighteval_job (line 62) | def run_lighteval_job(
  function run_benchmark_jobs (line 93) | def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], ...

FILE: src/open-r1-multimodal/src/open_r1/utils/hub.py
  function push_to_hub_revision (line 39) | def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ig...
  function check_hub_revision_exists (line 70) | def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
  function get_param_count_from_repo_id (line 88) | def get_param_count_from_repo_id(repo_id: str) -> int:
  function get_gpu_count_for_vllm (line 120) | def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_...

FILE: src/open-r1-multimodal/src/open_r1/utils/math.py
  function compute_score (line 2) | def compute_score(solution_str, ground_truth) -> float:
  function remove_boxed (line 25) | def remove_boxed(s):
  function last_boxed_only_string (line 38) | def last_boxed_only_string(string):
  function is_equiv (line 68) | def is_equiv(str1, str2, verbose=False):
  function fix_fracs (line 86) | def fix_fracs(string):
  function fix_a_slash_b (line 118) | def fix_a_slash_b(string):
  function remove_right_units (line 133) | def remove_right_units(string):
  function fix_sqrt (line 143) | def fix_sqrt(string):
  function strip_string (line 158) | def strip_string(string):

FILE: src/open-r1-multimodal/src/open_r1/utils/pycocotools/coco.py
  function _isArrayLike (line 20) | def _isArrayLike(obj):
  class COCO (line 24) | class COCO:
    method __init__ (line 25) | def __init__(self, annotation_file=None):
    method createIndex (line 47) | def createIndex(self):
    method info (line 78) | def info(self):
    method getAnnIds (line 86) | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
    method getCatIds (line 114) | def getCatIds(self, catNms=[], supNms=[], catIds=[]):
    method getImgIds (line 136) | def getImgIds(self, imgIds=[], catIds=[]):
    method loadAnns (line 157) | def loadAnns(self, ids=[]):
    method loadCats (line 168) | def loadCats(self, ids=[]):
    method loadImgs (line 179) | def loadImgs(self, ids=[]):
    method showAnns (line 190) | def showAnns(self, anns, draw_bbox=False):
    method loadRes (line 262) | def loadRes(self, resFile):
    method download (line 323) | def download(self, tarDir = None, imgIds = [] ):
    method loadNumpyAnnotations (line 347) | def loadNumpyAnnotations(self, data):
    method annToRLE (line 370) | def annToRLE(self, ann):
    method annToMask (line 391) | def annToMask(self, ann):

FILE: src/open-r1-multimodal/src/open_r1/utils/pycocotools/cocoeval.py
  class COCOeval (line 8) | class COCOeval:
    method __init__ (line 58) | def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
    method _prepare (line 82) | def _prepare(self):
    method evaluate (line 119) | def evaluate(self):
    method computeIoU (line 161) | def computeIoU(self, imgId, catId):
    method computeOks (line 190) | def computeOks(self, imgId, catId):
    method evaluateImg (line 233) | def evaluateImg(self, imgId, catId, aRng, maxDet):
    method accumulate (line 313) | def accumulate(self, p = None):
    method summarize (line 420) | def summarize(self):
    method __str__ (line 493) | def __str__(self):
  class Params (line 496) | class Params:
    method setDetParams (line 500) | def setDetParams(self):
    method setKpParams (line 511) | def setKpParams(self):
    method __init__ (line 523) | def __init__(self, iouType='segm'):

FILE: src/open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py
  class Qwen2VLModule (line 11) | class Qwen2VLModule(VLMBaseModule):
    method __init__ (line 12) | def __init__(self):
    method get_vlm_key (line 15) | def get_vlm_key(self):
    method get_model_class (line 18) | def get_model_class(self, model_id: str, model_init_kwargs: dict):
    method post_model_init (line 27) | def post_model_init(self, model, processing_class):
    method get_processing_class (line 30) | def get_processing_class(self):
    method get_vision_modules_keywords (line 33) | def get_vision_modules_keywords(self):
    method get_custom_multimodal_keywords (line 36) | def get_custom_multimodal_keywords(self):
    method get_non_generate_params (line 39) | def get_non_generate_params(self):
    method get_custom_processing_keywords (line 42) | def get_custom_processing_keywords(self):
    method prepare_prompt (line 45) | def prepare_prompt(self, processing_class, inputs: dict[str, Union[tor...
    method prepare_model_inputs (line 49) | def prepare_model_inputs(self, processing_class, prompts_text, images,...
    method get_question_template (line 71) | def get_question_template(task_type: str):
    method format_reward_rec (line 91) | def format_reward_rec(completions, **kwargs):
    method format_reward_robust (line 111) | def format_reward_robust(completions, **kwargs):
    method type_reward (line 131) | def type_reward(completions, solution, **kwargs):
    method accuracy_reward (line 191) | def accuracy_reward(completions, solution, **kwargs):
    method length_reward (line 216) | def length_reward(completions, solution, **kwargs):
    method iou_reward (line 241) | def iou_reward(completions, solution, **kwargs):
    method select_reward_func (line 293) | def select_reward_func(func: str, task_type: str):

FILE: src/open-r1-multimodal/src/open_r1/vlm_modules/vlm_module.py
  class VLMBaseModule (line 6) | class VLMBaseModule(ABC):
    method __init__ (line 7) | def __init__(self):
    method get_vlm_key (line 11) | def get_vlm_key(self):
    method get_model_class (line 15) | def get_model_class(self, model_id: str, model_init_kwargs: dict):
    method post_model_init (line 18) | def post_model_init(self, model, processing_class):
    method is_embeds_input (line 21) | def is_embeds_input(self):
    method get_processing_class (line 25) | def get_processing_class(self):
    method get_vision_modules_keywords (line 29) | def get_vision_modules_keywords(self):
    method get_custom_multimodal_keywords (line 33) | def get_custom_multimodal_keywords(self):
    method get_non_generate_params (line 37) | def get_non_generate_params(self):
    method get_custom_processing_keywords (line 41) | def get_custom_processing_keywords(self):
    method prepare_prompt (line 45) | def prepare_prompt(self, processing_class, inputs: dict[str, Union[tor...
    method prepare_model_inputs (line 49) | def prepare_model_inputs(self, processing_class, prompts_text, images,...
Condensed preview: 45 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 2733,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "README.md",
    "chars": 8351,
    "preview": "<div align=\"center\">\n\n# [AAAI 2026 Oral] Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding\nThis is "
  },
  {
    "path": "add_degradation/add_degradation.py",
    "chars": 11117,
    "preview": "import cv2\nimport numpy as np\nimport random\n\n\ndef motion_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray:\n  "
  },
  {
    "path": "add_degradation/generate_degradation.py",
    "chars": 3905,
    "preview": "import add_degradation\nimport cv2\nimport os\nimport numpy as np\nimport argparse\n\nDEGRADATION_CONFIG = {\n    'capture': {\n"
  },
  {
    "path": "app.py",
    "chars": 11440,
    "preview": "import gradio as gr\nimport os\nimport torch\nfrom transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor\nfr"
  },
  {
    "path": "demo.py",
    "chars": 6820,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nCLI Demo for Robust-R1: Visual Question Answering with Degradation-Aware Reasoning.\n\"\"\"\n\nimpo"
  },
  {
    "path": "requirements.txt",
    "chars": 208,
    "preview": "torch>=2.5.1\ntransformers==4.49.0\ngradio>=4.0.0\nqwen-vl-utils>=0.0.1\naccelerate>=1.2.1\nsentencepiece>=0.1.99\npillow\nsafe"
  },
  {
    "path": "run_scripts/run_grpo_robust.sh",
    "chars": 2057,
    "preview": "PROJECT_ROOT=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )/..\" && pwd )\"\nexport REPO_HOME=\"${PROJECT_ROOT}\"\necho \"REPO_HOME: "
  },
  {
    "path": "setup.sh",
    "chars": 704,
    "preview": "# conda create -n vlm-r1 python=3.11 \n# conda activate vlm-r1\n\n# Install the packages in open-r1-multimodal .\ncd src/ope"
  },
  {
    "path": "src/eval/test_od_r1.py",
    "chars": 6711,
    "preview": "import re\nimport os\nimport json\nimport torch\nimport random\n\nfrom tqdm import tqdm\nfrom pprint import pprint\nfrom qwen_vl"
  },
  {
    "path": "src/eval/test_rec_baseline.py",
    "chars": 7755,
    "preview": "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\nfrom qwen_vl_utils import proc"
  },
  {
    "path": "src/eval/test_rec_r1.py",
    "chars": 7847,
    "preview": "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\nfrom qwen_vl_utils import proc"
  },
  {
    "path": "src/eval/test_rec_r1_internvl.py",
    "chars": 7883,
    "preview": "import torch\nimport json\nfrom tqdm import tqdm\nimport re\nimport os\nfrom pprint import pprint\nimport random\nfrom transfor"
  },
  {
    "path": "src/open-r1-multimodal/.gitignore",
    "chars": 3474,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "src/open-r1-multimodal/LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "src/open-r1-multimodal/Makefile",
    "chars": 486,
    "preview": ".PHONY: style quality\n\n# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes"
  },
  {
    "path": "src/open-r1-multimodal/configs/ddp.yaml",
    "chars": 319,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\ngpu_ids: all\nmachine_ran"
  },
  {
    "path": "src/open-r1-multimodal/configs/zero2.yaml",
    "chars": 467,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  deepspeed_multinode_launcher: standard\n  offload_opt"
  },
  {
    "path": "src/open-r1-multimodal/configs/zero3.yaml",
    "chars": 496,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  deepspeed_multinode_launcher: standard\n  offload_opt"
  },
  {
    "path": "src/open-r1-multimodal/local_scripts/zero2.json",
    "chars": 1028,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "src/open-r1-multimodal/local_scripts/zero3.json",
    "chars": 1099,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "src/open-r1-multimodal/local_scripts/zero3.yaml",
    "chars": 498,
    "preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  deepspeed_multinode_launcher: standard\n  offload_opt"
  },
  {
    "path": "src/open-r1-multimodal/local_scripts/zero3_offload.json",
    "chars": 1288,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "src/open-r1-multimodal/local_scripts/zero_stage2_config.json",
    "chars": 671,
    "preview": "{\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"allgather_partitions\": true,\n    \"allgather_bucket_size\": 1e8,\n    \"over"
  },
  {
    "path": "src/open-r1-multimodal/setup.cfg",
    "chars": 696,
    "preview": "[isort]\ndefault_section = FIRSTPARTY\nensure_newline_before_comments = True\nforce_grid_wrap = 0\ninclude_trailing_comma = "
  },
  {
    "path": "src/open-r1-multimodal/setup.py",
    "chars": 4812,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/configs.py",
    "chars": 3111,
    "preview": "# coding=utf-8\n# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Versio"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/evaluate.py",
    "chars": 2453,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/generate.py",
    "chars": 4806,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/grpo_jsonl.py",
    "chars": 42760,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/qwen2_5vl_monkey_patch.py",
    "chars": 11560,
    "preview": "\n# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------\nfr"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/trainer/__init__.py",
    "chars": 106,
    "preview": "from .grpo_trainer import VLMGRPOTrainer\nfrom .grpo_config import GRPOConfig\n\n__all__ = [\"VLMGRPOTrainer\"]"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/trainer/grpo_config.py",
    "chars": 18505,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py",
    "chars": 44556,
    "preview": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/callbacks.py",
    "chars": 3139,
    "preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/evaluation.py",
    "chars": 4223,
    "preview": "import subprocess\nfrom typing import TYPE_CHECKING, Dict, Union\n\nfrom .hub import get_gpu_count_for_vllm, get_param_coun"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/hub.py",
    "chars": 5456,
    "preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/math.py",
    "chars": 6134,
    "preview": "from math_verify import parse, verify\ndef compute_score(solution_str, ground_truth) -> float:\n    retval = 0.\n    \n    i"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/pycocotools/coco.py",
    "chars": 16525,
    "preview": "import json\nimport time\nimport matplotlib.pyplot as plt\nfrom matplotlib.collections import PatchCollection\nfrom matplotl"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/utils/pycocotools/cocoeval.py",
    "chars": 24124,
    "preview": "import numpy as np\nimport datetime\nimport time\nfrom collections import defaultdict\nfrom pycocotools import mask as maskU"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/vlm_modules/__init__.py",
    "chars": 122,
    "preview": "from .vlm_module import VLMBaseModule\nfrom .qwen_module import Qwen2VLModule\n\n__all__ = [\"VLMBaseModule\", \"Qwen2VLModule"
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py",
    "chars": 14850,
    "preview": "from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor\nfrom typing "
  },
  {
    "path": "src/open-r1-multimodal/src/open_r1/vlm_modules/vlm_module.py",
    "chars": 1140,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Dict, Any, Union\nimport torch\n\n\nclass VLMBaseModule(ABC):\n    def"
  }
]

About this extraction

This page contains the full source code of the jqtangust/Robust-R1 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 45 files (300.6 KB), approximately 75.3k tokens, and a symbol index with 186 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.