Repository: jqtangust/Robust-R1 Branch: main Commit: 02f2c9d3e785 Files: 45 Total size: 300.6 KB Directory structure: gitextract_8_p77n13/ ├── .gitignore ├── README.md ├── add_degradation/ │ ├── add_degradation.py │ └── generate_degradation.py ├── app.py ├── demo.py ├── requirements.txt ├── run_scripts/ │ └── run_grpo_robust.sh ├── setup.sh └── src/ ├── eval/ │ ├── test_od_r1.py │ ├── test_rec_baseline.py │ ├── test_rec_r1.py │ └── test_rec_r1_internvl.py └── open-r1-multimodal/ ├── .gitignore ├── LICENSE ├── Makefile ├── configs/ │ ├── ddp.yaml │ ├── zero2.yaml │ └── zero3.yaml ├── local_scripts/ │ ├── zero2.json │ ├── zero3.json │ ├── zero3.yaml │ ├── zero3_offload.json │ └── zero_stage2_config.json ├── setup.cfg ├── setup.py └── src/ └── open_r1/ ├── __init__.py ├── configs.py ├── evaluate.py ├── generate.py ├── grpo_jsonl.py ├── qwen2_5vl_monkey_patch.py ├── trainer/ │ ├── __init__.py │ ├── grpo_config.py │ └── grpo_trainer.py ├── utils/ │ ├── __init__.py │ ├── callbacks.py │ ├── evaluation.py │ ├── hub.py │ ├── math.py │ └── pycocotools/ │ ├── coco.py │ └── cocoeval.py └── vlm_modules/ ├── __init__.py ├── qwen_module.py └── vlm_module.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 
db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv Pipfile.lock # poetry poetry.lock # pdm .pdm.toml # PEP 582 __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ .conda/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # IDEs .vscode/ .idea/ *.swp *.swo *~ .DS_Store *.sublime-project *.sublime-workspace # Model checkpoints and outputs checkpoints/ outputs/ runs/ *.pt *.pth *.ckpt *.safetensors *.bin *.h5 *.hdf5 *.onnx *.pb *.tflite *.pkl *.pickle # Training logs and outputs logs/ wandb/ tensorboard/ tb_logs/ *.log *.out *.err # Data files (uncomment if you don't want to track large data files) # data/ # datasets/ # *.csv # *.json # *.jsonl # *.parquet # *.arrow # Image files (uncomment if you don't want to track large image datasets) # *.jpg # *.jpeg # *.gif # *.bmp # *.tiff # *.webp # Allow PNG files in assets directory !assets/**/*.png !assets/**/*.jpg !assets/**/*.jpeg # Temporary files *.tmp *.temp *.bak *.swp *.swo *~ # OS files .DS_Store .DS_Store? 
._* .Spotlight-V100 .Trashes ehthumbs.db Thumbs.db desktop.ini # DeepSpeed deepspeed_logs/ deepspeed_info/ # Hugging Face cache .cache/ .huggingface/ transformers_cache/ # Local configuration files config.local.yaml config.local.yml *.local.* # Evaluation results eval_results/ results/ # JSON files (but keep config files in specific directories) *.json !**/local_scripts/*.json !**/config*.json !package.json !tsconfig.json !requirements.txt # Large files *.zip *.tar *.tar.gz *.rar *.7z # Node modules (if any) node_modules/ # Compiled models *.model *.weights ================================================ FILE: README.md ================================================
# [AAAI 2026 Oral] Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding This is the official repository for Robust-R1. [Jiaqi Tang^](https://jqt.me/), [Jianmin Chen^](https://github.com/Ch921-cell), \ [Wei Wei**](https://scholar.google.com/citations?hl=zh-CN&user=v8KMYlwAAAAJ), [Xiaogang Xu](https://xuxiaogang.com/), [Runtao Liu](https://scholar.google.com/citations?hl=zh-CN&user=YHTvXF4AAAAJ), [Xiangyu Wu](https://scholar.google.com/citations?user=R0GjVWIAAAAJ&hl=en), [Qipeng Xie](), [Jiafei Wu](), [Lei Zhang](https://scholar.google.com/citations?hl=zh-CN&user=0Kg6Gi4AAAAJ) and \ [Qifeng Chen*](https://cqf.io) ^: Equal contribution. *: Corresponding Author. **: Co-corresponding Author. [![Paper](https://img.shields.io/badge/cs.CV-Paper-b31b1b?style=flat&logo=arxiv&logoColor=white)](https://huggingface.co/papers/2512.17532) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-ffd21e)](https://huggingface.co/Jiaqi-hkust/Robust-R1) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Data-ffd21e)](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) [![made-for-VSCode](https://img.shields.io/badge/Made%20for-VSCode-1f425f.svg)](https://code.visualstudio.com/) [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
## 📰 **News**

- **[2025-12-23]** 🔥 Online demo is now available at [HF Space](https://huggingface.co/spaces/Jiaqi-hkust/Robust-R1).
- **[2025-12-23]** 🔥 We release the [Code](https://github.com/jqtangust/Robust-R1), [Models](https://huggingface.co/Jiaqi-hkust/Robust-R1), and [Dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) on HuggingFace.
- **[2025-12-22]** ✅ Our paper is now available on [arXiv](https://arxiv.org/abs/2512.17532).
- **[2025-11-08]** 🚀 Our paper is accepted by **AAAI 2026 Oral**.

---

## 🔭 **Motivation**

- 🚩 **Limited Interpretability**: Lack of explicit mechanisms to diagnose degradation impacts on original semantic information.
- 🚩 **Isolated Optimization**: Neglect of the degradation propagation relation between the visual encoder and large language model.
Method Overview
--- ## 🛠️ **Installation** - **Clone the repository:** ```bash git clone https://github.com/jqtangust/Robust-R1.git cd Robust-R1 ``` - **Create environment:** ```bash conda create -n robust_r1 python=3.10 conda activate robust_r1 bash setup.sh ``` --- ### 🏰 **Pretrained and Fine-tuned Model** - The following checkpoints are utilized to run Robust-R1: | Checkpoint | Link | Note | |:---------:|:----:|:----:| | Qwen2.5-VL-Base | [link](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) | Used as initial weights for training. | | **Robust-R1-SFT** | [link](https://huggingface.co/Jiaqi-hkust/Robust-R1-SFT) | Fine-tuned on [Robust-R1 dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) | | **Robust-R1-RL** | [link](https://huggingface.co/Jiaqi-hkust/Robust-R1-RL) | Fine-tuned with reinforcement learning on [Robust-R1 dataset](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) | --- ## ⏳ **Demo** ### 🖥️ CLI Demo - Run the command-line demo with a question: ```bash # if you use local weight export MODEL_PATH="your_model_name_or_path" python demo.py "What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n" ``` ### 🌐 GUI Demo - Set the model path as an environment variable and run the demo: ```bash # if you use local weight export MODEL_PATH="your_model_name_or_path" python app.py ``` - The demo will be available at `http://localhost:7860` by default. - GUI [Online Demo](https://huggingface.co/spaces/Jiaqi-hkust/Robust-R1).
Robust-R1 Demo
--- ## 🧠 **Training** ### 🎓 Supervised Fine-Tuning We employ [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) for supervised fine-tuning of the base model. 1. Clone the repository and install required dependencies: ```bash git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` 2. Download the base model [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct). 3. Prepare the training data and configuration files: - Download the [Robust images](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) and unzip it. - Modify the configuration files in the `LLaMA-Factory/data` directory. 4. Configure the training YAML file with your local paths (model path, data path, output directory.). 5. Run the training command to train the SFT model: ```bash llamafactory-cli train examples/train_full/qwen2_5_vl_full_sft.yaml ``` ### 🎓 Reinforcement Learning 1. Download [Robust images](https://huggingface.co/datasets/Jiaqi-hkust/Robust-R1) and unzip it in `Robust-R1/dataset`. 2. Prepare the training data file (train.jsonl) and organize the image folders. 3. Download the SFT model checkpoint from [Robust-R1-SFT](https://huggingface.co/Jiaqi-hkust/Robust-R1-SFT) or use your own trained SFT model. 4. Replace the following part in the [run_scripts/run_grpo_robust.sh](run_scripts/run_grpo_robust.sh) file with your own paths: ```bash data_paths="Robust-R1/data/train.jsonl" image_folders="Robust-R1/data/train_images" model_path="your_model_name_or_path" ``` 5. Run the script: ```bash bash run_scripts/run_grpo_robust.sh ``` --- ## 📊 **Evaluation** We use [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) for anti-degradation evaluation. 1. Clone the VLMEvalKit repository and install dependencies: ```bash git clone https://github.com/open-compass/VLMEvalKit.git cd VLMEvalKit pip install -e . ``` 2. Prepare the evaluation datasets according to VLMEvalKit requirements. 3. 
**Image Degradation Pipeline**: Generate corrupted images for robustness evaluation. We provide an image degradation pipeline for generating corrupted images to evaluate model robustness. Navigate to the degradation pipeline directory and process images:

```bash
cd add_degradation
python generate_degradation.py --input_dir <input_dir> --output_base_dir <output_dir> --dataset_name <dataset_name>
```

The script will generate three output directories with different degradation intensities for each image.

4. Configure the model path and evaluation settings in the VLMEvalKit configuration file.

5. Run the evaluation command:

```bash
python run.py --model <model_name> --data <dataset_name>
```

### 🔬 R-Bench Evaluation

For R-Bench evaluation, we use [R-Bench](https://github.com/Q-Future/R-Bench) to assess model performance under real-world corruptions.

1. Clone the R-Bench repository:

```bash
git clone https://github.com/Q-Future/R-Bench.git
```

2. Evaluate using VLMEvalKit with R-Bench dataset:

```bash
cd VLMEvalKit
python run.py --data R-Bench-Dis --model <model_name> --verbose
```

3. For full dataset evaluation, follow the R-Bench pipeline as described in the [R-Bench repository](https://github.com/Q-Future/R-Bench).

---

## ⭐️ Citation

If you find Robust-R1 useful for your research and applications, please cite using this BibTeX:

```latex
@inproceedings{tang2025robustr1,
  title={Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding},
  author={Tang, Jiaqi and Chen, Jianmin and Wei, Wei and Xu, Xiaogang and Liu, Runtao and Wu, Xiangyu and Xie, Qipeng and Wu, Jiafei and Zhang, Lei and Chen, Qifeng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2026}
}
```

## 🤝 Acknowledgements

The work described in this paper was supported by a grant from the Research Grants Council of the Hong Kong Special Administrative Region, China (Project Reference Number: AoE/E-601/24-N).
We also thank the authors of [VLM-R1](https://github.com/om-ai-lab/VLM-R1?tab=readme-ov-file), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [R-Bench](https://github.com/Q-Future/R-Bench) for their contributions. ================================================ FILE: add_degradation/add_degradation.py ================================================ import cv2 import numpy as np import random def motion_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") degree = max(5, int(intensity * 30)) angle = random.uniform(0, 360) M = cv2.getRotationMatrix2D((degree/2, degree/2), angle, 1) kernel = np.diag(np.ones(degree)) kernel = cv2.warpAffine(kernel, M, (degree, degree)) kernel /= np.sum(kernel) return cv2.filter2D(img, -1, kernel) def lens_blur(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") kernel_size = int(3 + intensity * 300) | 1 sigma = intensity * 20 kernel = cv2.getGaussianKernel(kernel_size, sigma) kernel = kernel @ kernel.T blurred = np.zeros_like(img, dtype=np.float32) for c in range(3): blurred[..., c] = cv2.filter2D(img[..., c].astype(np.float32), -1, kernel) result = cv2.addWeighted( img, 1 - intensity * 0.7, blurred.astype(np.uint8), intensity * 0.9, 0 ) return result def gaussian_noise(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") noise_std = intensity * 75 noise = np.random.normal(0, noise_std, img.shape) result = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8) return result def block_exchange(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] block_size = min(32, int(5 + intensity * 30)) noisy_img = img.copy() num_exchanges = int(intensity * 35) for _ in range(num_exchanges): i1 = random.randint(0, h // block_size - 1) j1 = random.randint(0, w // 
block_size - 1) i2 = random.randint(0, h // block_size - 1) j2 = random.randint(0, w // block_size - 1) y1, x1 = i1 * block_size, j1 * block_size y2, x2 = i2 * block_size, j2 * block_size block1 = noisy_img[y1:y1+block_size, x1:x1+block_size].copy() noisy_img[y1:y1+block_size, x1:x1+block_size] = \ noisy_img[y2:y2+block_size, x2:x2+block_size] noisy_img[y2:y2+block_size, x2:x2+block_size] = block1 return noisy_img def jpeg_compression(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if not 0 <= intensity <= 1: raise ValueError("Intensity must be in range [0.0, 1.0]") if img is None: raise ValueError("Input image is None") quality = int(100 - intensity * 95) quality = max(5, min(100, quality)) encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality] _, encimg = cv2.imencode('.jpg', img, encode_params) compressed_img = cv2.imdecode(encimg, cv2.IMREAD_COLOR) return compressed_img def mean_shift(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") spatial_radius = int(intensity * 40) color_radius = int(intensity * 40) return cv2.pyrMeanShiftFiltering(img, spatial_radius, color_radius) def color_diffusion(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") kernel_size = 3 + 2 * int(intensity * 20) sigma = intensity * 50 kernel = cv2.getGaussianKernel(kernel_size, sigma) kernel = kernel @ kernel.T * (intensity ** 2) diffused = np.zeros_like(img, dtype=np.float32) for c in range(3): diffused[..., c] = cv2.filter2D(img[..., c].astype(np.float32), -1, kernel) if intensity > 0.9: h, w = img.shape[:2] for _ in range(int(100 * intensity)): x, y = np.random.randint(0, w), np.random.randint(0, h) radius = np.random.randint(5, 20) cv2.circle(diffused, (x, y), radius, (np.random.randint(0, 255),) * 3, -1) result = cv2.addWeighted( img, max(0.1, 1 - intensity * 0.9), diffused.astype(np.uint8), min(0.9, intensity * 0.9), 0 ) return np.clip(result, 0, 
255).astype(np.uint8) def sharpness_change(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") if intensity > 0: kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]) * (intensity * 80) result = cv2.filter2D(img, -1, kernel) else: ksize = int(3 + abs(intensity) * 5) | 1 result = cv2.GaussianBlur(img, (ksize, ksize), 0) result = cv2.addWeighted(img, 0.7, result, 0.3, 0) return result def dark_illumination(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") result = (img * (1 - intensity ** 2)).clip(0, 255).astype(np.uint8) return result def hsv_saturation(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32) hsv[..., 1] *= (1 - intensity) result = cv2.cvtColor(hsv.clip(0, 255).astype(np.uint8), cv2.COLOR_HSV2BGR) return result def atmospheric_turbulence(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] x, y = np.meshgrid(np.arange(w), np.arange(h)) distortion = intensity * 40 * np.sin(y / 30 + intensity * 5) x_new = np.clip(x + distortion, 0, w - 1).astype(np.float32) y_new = np.clip(y + distortion * 0.7, 0, h - 1).astype(np.float32) return cv2.remap(img, x_new, y_new, cv2.INTER_LINEAR) def dirty_lens(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] dirt = np.zeros((h, w, 3), dtype=np.float32) if intensity > 0.1: for _ in range(int(10 * intensity)): center_x = random.randint(0, w) center_y = random.randint(0, h) cv2.ellipse(dirt, (center_x, center_y), (random.randint(150, 300), random.randint(100, 200)), angle=random.randint(0, 180), startAngle=0, endAngle=360, color=(50, 50, 50), thickness=-1) for _ in range(int(300 * intensity)): x = random.randint(0, w) y = 
random.randint(0, h) cv2.circle(dirt, (x, y), random.randint(4, 20), (random.randint(50, 100),) * 3, -1) if intensity > 0.5: for _ in range(int(5 * intensity)): x = random.randint(0, w) y = random.randint(0, h) cv2.circle(dirt, (x, y), random.randint(20, 50), (80, 80, 80), -1) cv2.circle(dirt, (x, y), random.randint(10, 30), (120, 120, 120), -1) dirt = cv2.GaussianBlur(dirt, (0, 0), 30) dirt = dirt.astype(np.uint8) result = cv2.addWeighted(img, 1 - 0.7 * intensity, dirt, 0.8 * intensity, 0) return np.clip(result, 0, 255).astype(np.uint8) def scan_lines(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") line_interval = max(3, int(20 / (intensity + 0.1))) line_width = max(5, int(7 * intensity)) result = img.copy() for i in range(0, img.shape[0], line_interval): end_line = min(i + line_width, img.shape[0]) result[i:end_line] = result[i:end_line] * 0.01 return result def graffiti(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if not 0 <= intensity <= 1: raise ValueError("Intensity must be in range [0.0, 1.0]") if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] result = img.copy() for _ in range(int(10 * intensity)): color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) pt1 = (random.randint(0, w - 1), random.randint(0, h - 1)) pt2 = (random.randint(0, w - 1), random.randint(0, h - 1)) thickness = random.randint(1, max(1, int(5 * intensity))) cv2.line(result, pt1, pt2, color, thickness) if intensity > 0.55: texts = ["X", "FAKE", "COPY", "VOID", "COPYRIGHT", str(random.randint(1, 100))] text = random.choice(texts) font_scale = max(0.5, intensity * 5) thickness = max(1, int(font_scale)) text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0] text_width, text_height = text_size if h - 10 > text_height + 10: text_x = random.randint(0, max(1, w - text_width - 10)) text_y = random.randint(text_height + 10, h - 10) 
cv2.putText(result, text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness) return result def watermark_damage(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] mask = np.zeros((h, w), dtype=np.float32) for _ in range(int(1 + intensity * 15)): x = random.randint(0, w - 50) y = random.randint(0, h - 50) cv2.rectangle(mask, (x, y), (x + random.randint(50, 200), y + random.randint(20, 80)), 1, -1) repaired = cv2.inpaint(img, (mask * 255).astype(np.uint8), 3, cv2.INPAINT_TELEA) result = cv2.addWeighted(img, 1 - intensity, repaired, intensity, 0) if intensity > 0.5: edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150) result[edges > 0] = result[edges > 0] * 0.8 return result def lens_flare(img: np.ndarray, intensity: float = 0.5) -> np.ndarray: if img is None: raise ValueError("Input image is None") h, w = img.shape[:2] flare = np.zeros((h, w, 3), dtype=np.float32) num_flares = 3 + int(30 * intensity) for _ in range(num_flares): x = random.randint(0, w) y = random.randint(0, h) radius = random.randint(10, 50) color = np.array([255, 255, 235]) cv2.circle(flare, (x, y), radius, color.tolist(), -1) angle = random.uniform(0, 2 * np.pi) length = random.randint(30, 150) end_x = int(x + length * np.cos(angle)) end_y = int(y + length * np.sin(angle)) cv2.line(flare, (x, y), (end_x, end_y), color.tolist(), 2) flare = cv2.GaussianBlur(flare, (3, 3), 20 * intensity) result = cv2.addWeighted(img.astype(np.float32), 1, flare, 0.9 * intensity, 0) return np.clip(result, 0, 255).astype(np.uint8) ================================================ FILE: add_degradation/generate_degradation.py ================================================ import add_degradation import cv2 import os import numpy as np import argparse DEGRADATION_CONFIG = { 'capture': { 'lens_blur': {'weight': 20}, 'lens_flare': {'weight': 20}, 'motion_blur': {'weight': 20}, 'dirty_lens': {'weight': 20}, 
'hsv_saturation': {'weight': 20} }, 'transmission': { 'jpeg_compression': {'weight': 25}, 'block_exchange': {'weight': 25}, 'mean_shift': {'weight': 25}, 'scan_lines': {'weight': 25} }, 'environment': { 'dark_illumination': {'weight': 25}, 'atmospheric_turbulence': {'weight': 25}, 'gaussian_noise': {'weight': 25}, 'color_diffusion': {'weight': 25} }, 'postprocessing': { 'sharpness_change': {'weight': 33}, 'graffiti': {'weight': 33}, 'watermark_damage': {'weight': 34} } } def apply_degradation_Benchmark(image, method_name, intensity): degradation_func = getattr(add_degradation, method_name) degraded_img = degradation_func(image, intensity) return degraded_img def main(): parser = argparse.ArgumentParser(description='Image degradation pipeline for robustness evaluation') parser.add_argument('--input_dir', type=str, default=os.getenv('INPUT_DIR', './data/images'), help='Input image directory path (can be set via INPUT_DIR environment variable)') parser.add_argument('--output_base_dir', type=str, default=os.getenv('OUTPUT_BASE_DIR', './data/output'), help='Base directory for output images (can be set via OUTPUT_BASE_DIR environment variable)') parser.add_argument('--dataset_name', type=str, default=os.getenv('DATASET_NAME', 'RealWorldQA'), help='Dataset name (used to generate output directory names)') args = parser.parse_args() folder_path = args.input_dir output_base_dir = args.output_base_dir dataset_name = args.dataset_name output_dirs = { 0.9: os.path.join(output_base_dir, f'{dataset_name}_Robust_100'), 0.45: os.path.join(output_base_dir, f'{dataset_name}_Robust_50'), 0.23: os.path.join(output_base_dir, f'{dataset_name}_Robust_25') } if not os.path.exists(folder_path): raise ValueError(f"Input directory does not exist: {folder_path}") for path in output_dirs.values(): os.makedirs(path, exist_ok=True) all_methods_with_weights = [] for category, methods in DEGRADATION_CONFIG.items(): for method_name, details in methods.items(): 
all_methods_with_weights.append((method_name, details['weight'])) method_names = [item[0] for item in all_methods_with_weights] weights = [item[1] for item in all_methods_with_weights] total_weight = sum(weights) probabilities = [w / total_weight for w in weights] num = 0 for filename in os.listdir(folder_path): if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): image_path = os.path.join(folder_path, filename) image = cv2.imread(image_path) if image is None: print(f"Warning: Could not read image {image_path}, skipping") num += 1 continue selected_method_name = np.random.choice(method_names, p=probabilities) for intensity, output_dir in output_dirs.items(): degraded_img = apply_degradation_Benchmark(image, selected_method_name, intensity) save_path = os.path.join(output_dir, filename) cv2.imwrite(save_path, degraded_img) num += 1 if num % 100 == 0: print(f"Processed {num} images") print("Processing completed!") if __name__ == '__main__': main() ================================================ FILE: app.py ================================================ import gradio as gr import os import torch from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import html sys_prompt = """First output the the types of degradations in image briefly in tags, and then output what effects do these degradation have on the image in tags, then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in tags, and then summarize the content of reasoning and the give the answer in tags, provides the user with the answer briefly in .""" project_dir = os.path.dirname(os.path.abspath(__file__)) temp_dir = os.path.join(project_dir, ".gradio_temp") os.makedirs(temp_dir, exist_ok=True) os.environ["GRADIO_TEMP_DIR"] = temp_dir MODEL_PATH = os.getenv("MODEL_PATH", "") if not MODEL_PATH: raise ValueError("MODEL_PATH environment variable must be set. 
Please set it to your model path.") print(f"==========================================") print(f"Initializing application...") print(f"==========================================") class ModelHandler: def __init__(self, model_path): self.model_path = model_path self.model = None self.processor = None self._load_model() def _load_model(self): try: print(f"⏳ Loading model weights, this may take a few minutes...") self.processor = AutoProcessor.from_pretrained(self.model_path) self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2" if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else "eager" ) print("✅ Model loaded successfully!") except Exception as e: print(f"❌ Model loading failed: {e}") raise e def predict(self, message_dict, history, temperature, max_tokens): text = message_dict.get("text", "") files = message_dict.get("files", []) messages = [] if history: print(f"Processing {len(history)} previous messages from history") for msg in history: role = msg.get("role", "") content = msg.get("content", "") if role == "user": user_content = [] if isinstance(content, list): for item in content: if isinstance(item, str): if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']): user_content.append({"type": "image", "image": item}) else: user_content.append({"type": "text", "text": item}) elif isinstance(item, dict): user_content.append(item) elif isinstance(content, str): if content: user_content.append({"type": "text", "text": content}) if user_content: messages.append({"role": "user", "content": user_content}) elif role == "assistant": if isinstance(content, str) and content: messages.append({"role": "assistant", "content": content}) current_content = [] if files: for file_path in files: current_content.append({"type": "image", "image": file_path}) if text: 
sys_prompt_formatted = " ".join(sys_prompt.split()) full_text = f"{text}\n{sys_prompt_formatted}" current_content.append({"type": "text", "text": full_text}) if current_content: messages.append({"role": "user", "content": current_content}) print(f"Total messages for model: {len(messages)}") print(f"Message roles: {[m['role'] for m in messages]}") text_prompt = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = self.processor( text=[text_prompt], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ) inputs = inputs.to(self.model.device) generation_kwargs = dict( **inputs, max_new_tokens=max_tokens, temperature=temperature, do_sample=True if temperature > 0 else False, ) try: print("Starting model generation...") with torch.no_grad(): generated_ids = self.model.generate(**generation_kwargs) input_length = inputs['input_ids'].shape[1] generated_ids = generated_ids[0][input_length:] print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}") generated_text = self.processor.tokenizer.decode( generated_ids, skip_special_tokens=True ) print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}") if generated_text and generated_text.strip(): print(f"Yielding generated text: {generated_text[:100]}...") yield generated_text else: warning_msg = "⚠️ No output generated. The model may not have produced any response." 
print(warning_msg) yield warning_msg except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in model.generate: {error_details}") yield f"❌ Generation error: {str(e)}" return model_handler = ModelHandler(MODEL_PATH) def create_chat_ui(): custom_css = """ .gradio-container { font-family: 'Inter', sans-serif; } #chatbot { height: 650px !important; overflow-y: auto; } """ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Robust-R1") as demo: with gr.Row(): gr.Markdown("# 🤖Robust-R1:Degradation-Aware Reasoning for Robust Visual Understanding") with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot( elem_id="chatbot", label="Chat", type="messages", avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"), height=650 ) chat_input = gr.MultimodalTextbox( interactive=True, file_types=["image"], placeholder="Enter your question or upload an image...", show_label=False ) with gr.Column(scale=1): with gr.Group(): gr.Markdown("### ⚙️ Generation Config") temperature = gr.Slider( minimum=0.01, maximum=1.0, value=0.6, step=0.05, label="Temperature" ) max_tokens = gr.Slider( minimum=128, maximum=4096, value=1024, step=128, label="Max New Tokens" ) clear_btn = gr.Button("🗑️ Clear Context", variant="stop") gr.Markdown("---") gr.Markdown("### 📚 Examples") gr.Markdown("Click the examples below to quickly fill the input box and start a conversation") example_images_dir = os.path.join(project_dir, "assets") examples_config = [ ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")), ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. 
sculpture\n", os.path.join(example_images_dir, "2.jpg")), ] example_data = [] for text, img_path in examples_config: if os.path.exists(img_path): example_data.append({"text": text, "files": [img_path]}) if example_data: gr.Examples( examples=example_data, inputs=chat_input, label="", examples_per_page=3 ) else: gr.Markdown("*No example images available, please manually upload images for testing*") async def respond(user_msg, history, temp, tokens): text = user_msg.get("text", "").strip() files = user_msg.get("files", []) user_content = list(files) if text: user_content.append(text) if not files and text: user_message = {"role": "user", "content": text} else: user_message = {"role": "user", "content": user_content} history.append(user_message) yield history, gr.MultimodalTextbox(value=None, interactive=False) history.append({"role": "assistant", "content": ""}) try: previous_history = history[:-2] if len(history) >= 2 else [] generated_text = "" for chunk in model_handler.predict(user_msg, previous_history, temp, tokens): generated_text = chunk safe_text = html.escape(generated_text) safe_text = generated_text.replace("<", "<").replace(">", ">") history[-1]["content"] = safe_text yield history, gr.MultimodalTextbox(interactive=False) except Exception as e: import traceback traceback.print_exc() history[-1]["content"] = f"❌ Inference error: {str(e)}" yield history, gr.MultimodalTextbox(interactive=True) yield history, gr.MultimodalTextbox(value=None, interactive=True) chat_input.submit( respond, inputs=[chat_input, chatbot, temperature, max_tokens], outputs=[chatbot, chat_input] ) def clear_history(): return [], None clear_btn.click(clear_history, outputs=[chatbot, chat_input]) return demo if __name__ == "__main__": demo = create_chat_ui() print(f"🚀 Service is starting, please visit: http://localhost:7862") demo.launch( server_name="0.0.0.0", server_port=7862, share=False, show_error=True, allowed_paths=[project_dir] ) ================================================ 
class ModelHandler:
    """Wraps a Qwen2.5-VL model + processor for single-image VQA inference."""

    def __init__(self, model_path):
        """Load the model and processor from *model_path* (local dir or HF hub id)."""
        self.model_path = model_path
        self.model = None
        self.processor = None
        self._load_model()

    def _load_model(self):
        """Instantiate the processor and the bf16 model.

        flash_attention_2 is only enabled on CUDA devices with compute
        capability >= 8 (Ampere or newer); otherwise falls back to eager.

        Raises:
            Exception: re-raises whatever from_pretrained raised, after
            printing a short diagnostic.
        """
        try:
            print("Loading model, this may take a few minutes...")
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            use_flash = (
                torch.cuda.is_available()
                and torch.cuda.get_device_capability()[0] >= 8
            )
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation="flash_attention_2" if use_flash else "eager",
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Model loading failed: {e}")
            # Bare `raise` preserves the original traceback; `raise e` would
            # reset it to this frame.
            raise

    def predict(self, question, image_path, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS):
        """
        Generate response for the given question and image.

        Args:
            question: User question
            image_path: Path to the image
            temperature: Generation temperature
            max_tokens: Maximum number of tokens to generate

        Returns:
            Generated text response (completion only, prompt stripped).

        Raises:
            Exception: re-raises any generation failure after printing the
            full traceback.
        """
        # Collapse the multi-line system prompt into a single line.
        sys_prompt_formatted = " ".join(SYS_PROMPT.split())
        full_text = f"{question}\n{sys_prompt_formatted}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": full_text},
                    {"type": "image", "image": image_path},
                ],
            }
        ]

        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)

        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            # Greedy decoding when temperature is 0 (sampling would divide by it).
            do_sample=temperature > 0,
        )

        try:
            print("Generating response...")
            with torch.no_grad():
                generated_ids = self.model.generate(**generation_kwargs)

            # Strip the prompt tokens so only the completion is decoded.
            input_length = inputs['input_ids'].shape[1]
            generated_ids = generated_ids[0][input_length:]
            generated_text = self.processor.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            )
            return generated_text
        except Exception:
            import traceback
            error_details = traceback.format_exc()
            print(f"Generation error: {error_details}")
            raise
--temperature 0.7 --max-tokens 2048 python demo.py "Your question" --image /path/to/image.jpg """ ) parser.add_argument( "question", type=str, help="Question to ask about the image" ) parser.add_argument( "--image", "-i", type=str, default=FIXED_IMAGE_PATH, help=f"Path to the input image (default: {FIXED_IMAGE_PATH})" ) parser.add_argument( "--temperature", "-t", type=float, default=DEFAULT_TEMPERATURE, help=f"Generation temperature (default: {DEFAULT_TEMPERATURE})" ) parser.add_argument( "--max-tokens", "-m", type=int, default=DEFAULT_MAX_TOKENS, help=f"Maximum number of tokens to generate (default: {DEFAULT_MAX_TOKENS})" ) parser.add_argument( "--model-path", type=str, default=MODEL_PATH, help=f"Model path or HuggingFace model name (default: {MODEL_PATH}). Can also be set via MODEL_PATH environment variable." ) args = parser.parse_args() if not os.path.exists(args.image): print(f"Error: Image file does not exist: {args.image}") sys.exit(1) print(f"Model path: {args.model_path}") print(f"Image path: {args.image}") print(f"Question: {args.question}") print(f"Temperature: {args.temperature}, Max tokens: {args.max_tokens}") print("-" * 80) model_handler = ModelHandler(args.model_path) try: response = model_handler.predict( question=args.question, image_path=args.image, temperature=args.temperature, max_tokens=args.max_tokens ) print("\n" + "=" * 80) print("Model Response:") print("=" * 80) print(response) print("=" * 80) except KeyboardInterrupt: print("\n\nUser interrupted") sys.exit(0) except Exception as e: print(f"\nError: {e}") import traceback traceback.print_exc() sys.exit(1) if __name__ == "__main__": main() ================================================ FILE: requirements.txt ================================================ torch>=2.5.1 transformers==4.49.0 gradio>=4.0.0 qwen-vl-utils>=0.0.1 accelerate>=1.2.1 sentencepiece>=0.1.99 pillow safetensors>=0.3.3 huggingface-hub>=0.19.2,<1.0 einops>=0.8.0 packaging>=23.0 numpy>=1.21.0 
================================================ FILE: run_scripts/run_grpo_robust.sh ================================================ PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" export REPO_HOME="${PROJECT_ROOT}" echo "REPO_HOME: $REPO_HOME" # Change the data_paths and image_folders to your own data data_paths="your_data_path" image_folders="your_images_folder" model_path="your_model_name_or_path" is_reward_customized_from_vlm_module=True echo "data_paths: $data_paths" echo "image_folders: $image_folders" export EXP_NAME="your_experiment_name" # TODO: change this to your own experiment name TASK_TYPE="robust" cd ${REPO_HOME}/src/open-r1-multimodal export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL # create the run directory and log file mkdir -p ${REPO_HOME}/runs/${EXP_NAME}/log export LOG_PATH="${REPO_HOME}/runs/${EXP_NAME}/log/debug_log.$(date +%Y-%m-%d-%H-%M-%S).txt" # MAX_STEPS=1200 # TODO: change this to your own max steps # export WANDB_DISABLED=true torchrun --nproc_per_node="8" \ --nnodes="1" \ --node_rank="0" \ --master_addr="127.0.0.1" \ --master_port="12352" \ src/open_r1/grpo_jsonl.py \ --use_vllm False \ --output_dir ${REPO_HOME}/checkpoints/rl/${EXP_NAME} \ --resume_from_checkpoint True \ --model_name_or_path $model_path \ --data_file_paths $data_paths \ --image_folders $image_folders \ --is_reward_customized_from_vlm_module $is_reward_customized_from_vlm_module \ --task_type $TASK_TYPE \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 2\ --gradient_checkpointing true \ --logging_steps 1 \ --num_train_epochs 1 \ --bf16 \ --attn_implementation flash_attention_2 \ --run_name ${EXP_NAME} \ --data_seed 42 \ --save_steps 100 \ --num_generations 8 \ --max_completion_length 2048 \ --reward_funcs accuracy format type length\ --beta 0.04 \ --report_to none \ --dataset-name this_is_not_used \ --deepspeed ${REPO_HOME}/src/open-r1-multimodal/local_scripts/zero3.json \ 
def extract_bbox_answer(content):
    """Extract a bounding box from a fenced ```json ...``` block in *content*.

    Expects the JSON payload to be a list whose first element carries a
    "bbox_2d" key (the Qwen OD output format).

    Returns:
        (bbox, normalized) — bbox is the [x1, y1, x2, y2] list from the JSON,
        or [0, 0, 0, 0] when the fence is absent or malformed. The
        `normalized` flag is always False for this extractor.
    """
    json_match = re.search(r'```json(.*?)```', content, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(1).strip())[0]['bbox_2d'], False
        except (json.JSONDecodeError, KeyError, IndexError, TypeError):
            # Malformed JSON or unexpected structure — fall through to the
            # default. (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass
    return [0, 0, 0, 0], False
def load_model(model_path, device_map):
    """Load a Qwen2.5-VL conditional-generation model and its processor.

    flash_attention_2 is enabled for better acceleration and memory saving,
    especially in multi-image and video scenarios.

    Returns:
        (model, processor) tuple ready for inference.
    """
    processor = AutoProcessor.from_pretrained(model_path)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device_map,
    )
    return model, processor
clean_up_tokenization_spaces=False) all_outputs.extend(batch_output_text) final_output = [] correct_number = 0 for input_example, model_output in zip(data, all_outputs): original_output = model_output ground_truth = input_example['solution'] ground_truth_normalized = input_example['normalized_solution'] model_answer, normalized = extract_bbox_answer(original_output) # Count correct answers correct = 0 if model_answer is not None: iou_value = iou(model_answer, ground_truth_normalized if normalized else ground_truth) if iou_value > 0.5: correct = 1 correct_number += correct # Create a result dictionary for this example result = { "question": question_template.format(Question=input_example['normal_caption']), "ground_truth": ground_truth if not normalized else ground_truth_normalized, "model_output": original_output, "extracted_answer": model_answer, "correct": correct, "iou": iou_value } final_output.append(result) # Calculate and print accuracy accuracy = correct_number / len(data) * 100 print(f"\nAccuracy of {ds}: {accuracy:.2f}%") # Save results to a JSON file result_path = os.path.join(output_dir, f"{os.path.basename(model_path)}", f"{ds}_od_r1.json") os.makedirs(os.path.dirname(result_path), exist_ok=True) with open(result_path, "w") as f: json.dump({"accuracy": accuracy, "results": final_output}, f, indent=2) print(f"Results saved to {result_path}") print('-' * 100) if __name__ == "__main__": model_path = '' # Add the path to the model data_root = '' # Add the data root test_datasets = ['refcoco_val', 'refcocop_val', 'refcocog_val'] # modify the datasets image_root = '' # Add the image root output_dir = 'logs' # Add the output directory, default is logs device_map = 'cuda:0' # select the device, default is cuda:0 question_template = '{Question} First output the thinking process in tags and then output the final answer in tags. Output the final answer in JSON format.' 
# modify the question template which must contain {Question}, {Question} will be replaced by the caption eval_od_r1( model_path=model_path, data_root=data_root, test_datasets=test_datasets, image_root=image_root, question_template=question_template, output_dir=output_dir, device_map=device_map ) ================================================ FILE: src/eval/test_rec_baseline.py ================================================ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor from qwen_vl_utils import process_vision_info import torch import json from tqdm import tqdm import re import os from pprint import pprint import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import argparse import warnings warnings.filterwarnings("ignore", category=UserWarning, module="transformers") def setup_distributed(): local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl") world_size = dist.get_world_size() rank = dist.get_rank() print(f"Process {rank}/{world_size} initialized on cuda:{local_rank}") return local_rank, world_size, rank local_rank, world_size, rank = setup_distributed() device = f"cuda:{local_rank}" steps = 100 MODEL_PATH=f"/data10/shz/project/LLaMA-Factory/saves/qwen2_5_vl-3b/full/sft/checkpoint-{steps}" OUTPUT_PATH="./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_sft_{STEPS}.json" # MODEL_PATH = "/data10/shz/ckpt/vlm-r1-related/Qwen2.5-VL-3B-Instruct" # OUTPUT_PATH = "./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_baseline_{STEPS}.json" BSZ=4 DATA_ROOT = "/data10/shz/dataset/rec/rec_jsons_processed" TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val'] IMAGE_ROOT = "/data10/shz/dataset/coco" # TEST_DATASETS = ['lisa_test'] # IMAGE_ROOT = "/data10/shz/dataset/lisa" #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. 
def extract_bbox_answer(content):
    """Pull the first bracketed 4-number list "[x1, y1, x2, y2]" out of *content*.

    Returns:
        The four coordinates as floats, or [0, 0, 0, 0] when no bracketed
        4-tuple of numbers is found.
    """
    bbox_pattern = (
        r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),'
        r'\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    )
    match = re.search(bbox_pattern, content)
    if match is None:
        return [0, 0, 0, 0]
    return [float(group) for group in match.groups()]
# Split data for distributed evaluation per_rank_data = len(data) // world_size start_idx = rank * per_rank_data end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data) rank_data = data[start_idx:end_idx] messages = [] for x in rank_data: image_path = os.path.join(IMAGE_ROOT, x['image']) message = [ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, { "role": "user", "content": [ { "type": "image", "image": f"file://{image_path}" }, { "type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem']) } ] }] messages.append(message) rank_outputs = [] # List to store answers for this rank all_outputs = [] # List to store all answers # Process data for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0): batch_messages = messages[i:i + BSZ] # Preparation for inference text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages] image_inputs, video_inputs = process_vision_info(batch_messages) inputs = processor( text=text, images=image_inputs, videos=video_inputs, padding=True, padding_side="left", return_tensors="pt", ) inputs = inputs.to(device) # Inference: Generation of the output generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] batch_output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) rank_outputs.extend(batch_output_text) print(f"Rank {rank} has finished processing {len(rank_outputs)} examples") # Gather all outputs from all ranks all_outputs = [None] * len(data) rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)] gathered_results = [None] * world_size dist.all_gather_object(gathered_results, rank_results) assert gathered_results[-1][-1][0] == len(data) - 1 # The main process will collect all results if 
rank == 0: for results in gathered_results: for idx, output in results: assert idx < len(all_outputs) all_outputs[idx] = output assert all_outputs[-1] is not None final_output = [] correct_number = 0 for input_example, model_output in zip(data, all_outputs): original_output = model_output ground_truth = input_example['solution'] model_answer = extract_bbox_answer(original_output) # Count correct answers correct = 0 if model_answer is not None: if iou(model_answer, ground_truth) > 0.5: correct = 1 correct_number += correct # Create a result dictionary for this example result = { 'image': input_example['image'], 'question': input_example['problem'], 'ground_truth': ground_truth, 'model_output': original_output, 'extracted_answer': model_answer, 'correct': correct } final_output.append(result) # Calculate and print accuracy accuracy = correct_number / len(data) * 100 print(f"\nAccuracy of {ds}: {accuracy:.2f}%") # Save results to a JSON file output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps) output_dir = os.path.dirname(output_path) if not os.path.exists(output_dir): os.makedirs(output_dir) with open(output_path, "w") as f: json.dump({ 'accuracy': accuracy, 'results': final_output }, f, indent=2) print(f"Results saved to {output_path}") print("-"*100) # Synchronize all processes dist.barrier() ================================================ FILE: src/eval/test_rec_r1.py ================================================ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor from qwen_vl_utils import process_vision_info import torch import json from tqdm import tqdm import re import os from pprint import pprint import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import argparse import warnings warnings.filterwarnings("ignore", category=UserWarning, module="transformers") def setup_distributed(): local_rank = int(os.environ.get("LOCAL_RANK", 0)) 
def extract_bbox_answer(content):
    """Extract an integer bbox from the model's answer section.

    NOTE(review): the extracted source showed the tag pattern as r'(.*?)',
    which matches the empty string at position 0 and therefore can never
    yield a bbox. The surrounding prompt asks the model to put the final
    answer inside <answer> tags, so the pattern is restored to
    '<answer>(.*?)</answer>' — confirm against the original repository.

    Returns:
        [x1, y1, x2, y2] as ints, or [0, 0, 0, 0] when no answer tag or no
        bbox inside it is found.
    """
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
    content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if content_answer_match:
        content_answer = content_answer_match.group(1).strip()
        bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
        if bbox_match:
            return [int(bbox_match.group(i)) for i in range(1, 5)]
    return [0, 0, 0, 0]
(inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) else: inter = 0 union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter return float(inter)/union num_samples = 2000 for ds in TEST_DATASETS: if rank == 0: print(f"Processing {ds}...") ds_path = os.path.join(DATA_ROOT, f"{ds}.json") data = json.load(open(ds_path, "r")) random.seed(42) random.shuffle(data) data = data[:num_samples] QUESTION_TEMPLATE = "{Question} First output the thinking process in tags and then output the final answer in tags. Output the final answer in JSON format." # Split data for distributed evaluation per_rank_data = len(data) // world_size start_idx = rank * per_rank_data end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data) rank_data = data[start_idx:end_idx] messages = [] for x in rank_data: image_path = os.path.join(IMAGE_ROOT, x['image']) message = [ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, { "role": "user", "content": [ { "type": "image", "image": f"file://{image_path}" }, { "type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem']) } ] }] messages.append(message) rank_outputs = [] # List to store answers for this rank all_outputs = [] # List to store all answers # Process data for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank): batch_messages = messages[i:i + BSZ] # Preparation for inference text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages] image_inputs, video_inputs = process_vision_info(batch_messages) inputs = processor( text=text, images=image_inputs, videos=video_inputs, padding=True, padding_side="left", return_tensors="pt", ) inputs = inputs.to(device) # Inference: Generation of the output generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] 
batch_output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) rank_outputs.extend(batch_output_text) print(f"Rank {rank} has finished processing {len(rank_outputs)} examples") # Gather all outputs from all ranks all_outputs = [None] * len(data) rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)] gathered_results = [None] * world_size dist.all_gather_object(gathered_results, rank_results) assert gathered_results[-1][-1][0] == len(data) - 1 # The main process will collect all results if rank == main_rank: for results in gathered_results: for idx, output in results: assert idx < len(all_outputs) all_outputs[idx] = output assert all_outputs[-1] is not None final_output = [] correct_number = 0 for input_example, model_output in zip(data, all_outputs): original_output = model_output ground_truth = input_example['solution'] model_answer = extract_bbox_answer(original_output) # Count correct answers correct = 0 if model_answer is not None: if iou(model_answer, ground_truth) > 0.5: correct = 1 correct_number += correct # Create a result dictionary for this example result = { 'image': input_example['image'], 'question': input_example['problem'], 'ground_truth': ground_truth, 'model_output': original_output, 'extracted_answer': model_answer, 'correct': correct } final_output.append(result) # Calculate and print accuracy accuracy = correct_number / len(data) * 100 print(f"\nAccuracy of {ds}: {accuracy:.2f}%") # Save results to a JSON file output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps) output_dir = os.path.dirname(output_path) if not os.path.exists(output_dir): os.makedirs(output_dir) with open(output_path, "w") as f: json.dump({ 'accuracy': accuracy, 'results': final_output }, f, indent=2) print(f"Results saved to {output_path}") print("-"*100) # Synchronize all processes dist.barrier() ================================================ FILE: 
def setup_distributed():
    """Initialise torch.distributed (NCCL backend) for this process.

    Reads LOCAL_RANK from the environment (defaulting to 0) and pins the
    process to that CUDA device before joining the process group.

    Returns:
        (local_rank, world_size, rank) tuple of ints.
    """
    device_index = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend="nccl")
    return device_index, dist.get_world_size(), dist.get_rank()
def iou(box1, box2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes.

    Keeps the original convention of treating x2/y2 as exclusive: 1 is
    subtracted before the overlap test and added back when counting pixels.

    Returns:
        IoU in [0.0, 1.0]. Degenerate inputs whose union area is zero
        return 0.0 instead of raising ZeroDivisionError.
    """
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2] - 1, box2[2] - 1)
    inter_y2 = min(box1[3] - 1, box2[3] - 1)
    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
        inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
    else:
        inter = 0
    union = ((box1[2] - box1[0]) * (box1[3] - box1[1])
             + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter)
    if union == 0:
        # Both boxes have zero area: define IoU as 0 rather than crash.
        return 0.0
    return float(inter) / union
# Split data for distributed evaluation per_rank_data = len(data) // world_size start_idx = rank * per_rank_data end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data) rank_data = data[start_idx:end_idx] messages = [] for x in rank_data: image_path = os.path.join(IMAGE_ROOT, x['image']) message = [ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, { "role": "user", "content": [ { "type": "image", "image": f"file://{image_path}" }, { "type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem']) } ] }] messages.append(message) rank_outputs = [] # List to store answers for this rank all_outputs = [] # List to store all answers # Process data for i in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank): batch_messages = messages[i:i + BSZ] prompts = vlm_module.prepare_prompt(None, [{"prompt": msg} for msg in batch_messages]) images = process_vision_info(batch_messages) model_inputs = vlm_module.prepare_model_inputs(tokenizer, prompts, images) model_inputs['pixel_values'] = model_inputs['pixel_values'].to(torch.bfloat16) model_inputs = model_inputs.to(device) outputs = model.generate(**{k:v for k,v in model_inputs.items() if k not in vlm_module.get_non_generate_params()}, max_new_tokens=1024, do_sample=False, pad_token_id=tokenizer.eos_token_id) batch_output_text = tokenizer.batch_decode( outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False ) rank_outputs.extend(batch_output_text) print(f"Rank {rank} has finished processing {len(rank_outputs)} examples") # Gather all outputs from all ranks all_outputs = [None] * len(data) rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)] gathered_results = [None] * world_size dist.all_gather_object(gathered_results, rank_results) assert gathered_results[-1][-1][0] == len(data) - 1 # The main process will collect all results if rank == main_rank: for results in gathered_results: for idx, output in results: assert idx < 
len(all_outputs) all_outputs[idx] = output assert all_outputs[-1] is not None final_output = [] correct_number = 0 for input_example, model_output in zip(data, all_outputs): original_output = model_output ground_truth = input_example['solution'] model_answer = extract_bbox_answer(original_output) # Count correct answers correct = 0 if model_answer is not None and iou(model_answer, ground_truth) > 0.5: correct = 1 correct_number += correct # Create a result dictionary for this example result = { 'image': input_example['image'], 'question': input_example['problem'], 'ground_truth': ground_truth, 'model_output': original_output, 'extracted_answer': model_answer, 'correct': correct } final_output.append(result) # Calculate and print accuracy accuracy = correct_number / len(data) * 100 print(f"\nAccuracy of {ds}: {accuracy:.2f}%") # Save results to a JSON file output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps) output_dir = os.path.dirname(output_path) if not os.path.exists(output_dir): os.makedirs(output_dir) with open(output_path, "w") as f: json.dump({ 'accuracy': accuracy, 'results': final_output }, f, indent=4) print(f"Results saved to {output_path}") print("-"*100) # Synchronize all processes dist.barrier() ================================================ FILE: src/open-r1-multimodal/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. 
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ # PyPI configuration file .pypirc # Temp folders data/ wandb/ scripts/ checkpoints/ .vscode/ ================================================ FILE: src/open-r1-multimodal/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: src/open-r1-multimodal/Makefile ================================================ .PHONY: style quality # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src check_dirs := src style: black --line-length 119 --target-version py310 $(check_dirs) setup.py isort $(check_dirs) setup.py quality: black --check --line-length 119 --target-version py310 $(check_dirs) setup.py isort --check-only $(check_dirs) setup.py flake8 --max-line-length 119 $(check_dirs) setup.py # Evaluation evaluate: ================================================ FILE: src/open-r1-multimodal/configs/ddp.yaml ================================================ compute_environment: LOCAL_MACHINE debug: false distributed_type: MULTI_GPU downcast_bf16: 'no' gpu_ids: all machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 8 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false ================================================ FILE: src/open-r1-multimodal/configs/zero2.yaml ================================================ compute_environment: LOCAL_MACHINE debug: false deepspeed_config: deepspeed_multinode_launcher: standard offload_optimizer_device: none offload_param_device: none zero3_init_flag: false zero_stage: 2 distributed_type: 
DEEPSPEED downcast_bf16: 'no' machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 8 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false ================================================ FILE: src/open-r1-multimodal/configs/zero3.yaml ================================================ compute_environment: LOCAL_MACHINE debug: false deepspeed_config: deepspeed_multinode_launcher: standard offload_optimizer_device: cpu offload_param_device: cpu zero3_init_flag: true zero3_save_16bit_model: true zero_stage: 3 distributed_type: DEEPSPEED downcast_bf16: 'no' machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 2 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false ================================================ FILE: src/open-r1-multimodal/local_scripts/zero2.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "optimizer": { "type": "AdamW", "params": { "lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto" } }, "zero_optimization": { "stage": 2, "offload_optimizer": { "device": "none", "pin_memory": true }, "allgather_partitions": true, "allgather_bucket_size": 2e8, "overlap_comm": false, "reduce_scatter": true, "reduce_bucket_size": 2e8, "contiguous_gradients": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 100, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false } ================================================ FILE: src/open-r1-multimodal/local_scripts/zero3.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, 
"initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": true }, "offload_param": { "device": "cpu", "pin_memory": true }, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 100, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false } ================================================ FILE: src/open-r1-multimodal/local_scripts/zero3.yaml ================================================ compute_environment: LOCAL_MACHINE debug: false deepspeed_config: deepspeed_multinode_launcher: standard offload_optimizer_device: none offload_param_device: none zero3_init_flag: true zero3_save_16bit_model: true zero_stage: 3 distributed_type: DEEPSPEED downcast_bf16: 'no' machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 8 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false ================================================ FILE: src/open-r1-multimodal/local_scripts/zero3_offload.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "optimizer": { "type": "AdamW", "params": { "lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto" } }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": true }, "offload_param": { "device": "cpu", "pin_memory": true 
}, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "gather_16bit_weights_on_model_save": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "steps_per_print": 1e5, "wall_clock_breakdown": false } ================================================ FILE: src/open-r1-multimodal/local_scripts/zero_stage2_config.json ================================================ { "zero_optimization": { "stage": 2, "allgather_partitions": true, "allgather_bucket_size": 1e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 1e8, "contiguous_gradients": true }, "fp16": { "enabled": "auto", "auto_cast": true, "loss_scale": 0, "initial_scale_power": 32, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 2000, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false } ================================================ FILE: src/open-r1-multimodal/setup.cfg ================================================ [isort] default_section = FIRSTPARTY ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True known_first_party = open_r1 known_third_party = transformers datasets fugashi git h5py matplotlib nltk numpy packaging pandas psutil pytest rouge_score sacrebleu seqeval sklearn streamlit torch tqdm line_length = 119 lines_after_imports = 2 multi_line_output = 3 use_parentheses = True [flake8] ignore = E203, E501, E741, W503, W605 max-line-length = 119 per-file-ignores = # imported but unused __init__.py: F401 [tool:pytest] doctest_optionflags=NUMBER NORMALIZE_WHITESPACE 
ELLIPSIS ================================================ FILE: src/open-r1-multimodal/setup.py ================================================ # Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py import re import shutil from pathlib import Path from setuptools import find_packages, setup # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 stale_egg_info = Path(__file__).parent / "open_r1.egg-info" if stale_egg_info.exists(): print( ( "Warning: {} exists.\n\n" "If you recently updated open_r1, this is expected,\n" "but it may prevent open_r1 from installing in editable mode.\n\n" "This directory is automatically generated by Python's packaging tools.\n" "I will remove it now.\n\n" "See https://github.com/pypa/pip/issues/5466 for details.\n" ).format(stale_egg_info) ) shutil.rmtree(stale_egg_info) # IMPORTANT: all dependencies should be listed here with their version requirements, if any. # * If a dependency is fast-moving (e.g. 
transformers), pin to the exact version _deps = [ "accelerate>=1.2.1", "bitsandbytes>=0.43.0", "black>=24.4.2", "datasets>=3.2.0", "deepspeed==0.15.4", "distilabel[vllm,ray,openai]>=1.5.2", "einops>=0.8.0", "flake8>=6.0.0", "hf_transfer>=0.1.4", "huggingface-hub[cli]>=0.19.2,<1.0", "isort>=5.12.0", "liger_kernel==0.5.2", # "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", "math-verify", # Used for math verification in grpo "packaging>=23.0", "parameterized>=0.9.0", "pytest", "safetensors>=0.3.3", "sentencepiece>=0.1.99", "torch>=2.5.1", "transformers==4.49.0", "trl @ git+https://github.com/huggingface/trl.git@main", "vllm==0.6.6.post1", "wandb>=0.19.1", "pillow", ] # this is a lookup table with items like: # # tokenizers: "tokenizers==0.9.4" # packaging: "packaging" # # some of the values are versioned whereas others aren't. deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} def deps_list(*pkgs): return [deps[pkg] for pkg in pkgs] extras = {} extras["tests"] = deps_list("pytest", "parameterized") extras["torch"] = deps_list("torch") extras["quality"] = deps_list("black", "isort", "flake8") # extras["eval"] = deps_list("lighteval", "math-verify") extras["eval"] = deps_list("math-verify") extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] # core dependencies shared across the whole project - keep this to a bare minimum :) install_requires = [ deps["accelerate"], deps["bitsandbytes"], deps["einops"], deps["datasets"], deps["deepspeed"], deps["hf_transfer"], deps["huggingface-hub"], deps["liger_kernel"], deps["packaging"], # utilities from PyPA to e.g., compare versions deps["safetensors"], deps["sentencepiece"], deps["transformers"], deps["trl"], ] setup( name="open-r1", version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) author="The Hugging Face team 
(past and future)", author_email="lewis@huggingface.co", description="Open R1", # long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", keywords="llm inference-time compute reasoning", license="Apache", url="https://github.com/huggingface/open-r1", package_dir={"": "src"}, packages=find_packages("src"), zip_safe=False, extras_require=extras, python_requires=">=3.10.9", install_requires=install_requires, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) ================================================ FILE: src/open-r1-multimodal/src/open_r1/__init__.py ================================================ ================================================ FILE: src/open-r1-multimodal/src/open_r1/configs.py ================================================ # coding=utf-8 # Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass, field
from typing import Optional

import trl


# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
    """Extension of ``trl.GRPOConfig`` with run-tracking extras.

    Adds post-training benchmark/callback selection, Hub revision handling,
    and Weights & Biases entity/project fields on top of the TRL base config.
    """

    # Benchmarks to run once training finishes (names resolved elsewhere).
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    # Callback names to attach during training.
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )


@dataclass
class SFTConfig(trl.SFTConfig):
    """Extension of ``trl.SFTConfig`` with the same run-tracking extras as
    :class:`GRPOConfig` (kept in sync by hand; see the mixin TODO above)."""

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use for benchmarking."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )



================================================
FILE: src/open-r1-multimodal/src/open_r1/evaluate.py
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language


# Extractive-match metric: gold answers are parsed as LaTeX; predictions may be
# either plain expressions or LaTeX. `max` keeps the best match across targets.
metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,
)


def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    # Single-choice Doc: the dataset's reference solution is the only gold answer.
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


# Define tasks
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
evaluation_splits=["train"], few_shots_split=None, few_shots_select=None, generation_size=32768, metric=[metric], version=1, ) math_500 = LightevalTaskConfig( name="math_500", suite=["custom"], prompt_function=prompt_fn, hf_repo="HuggingFaceH4/MATH-500", hf_subset="default", hf_avail_splits=["test"], evaluation_splits=["test"], few_shots_split=None, few_shots_select=None, generation_size=32768, metric=[metric], version=1, ) # Add tasks to the table TASKS_TABLE = [] TASKS_TABLE.append(aime24) TASKS_TABLE.append(math_500) # MODULE LOGIC if __name__ == "__main__": print([t["name"] for t in TASKS_TABLE]) print(len(TASKS_TABLE)) ================================================ FILE: src/open-r1-multimodal/src/open_r1/generate.py ================================================ # Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional from distilabel.llms import OpenAILLM from distilabel.pipeline import Pipeline from distilabel.steps.tasks import TextGeneration def build_distilabel_pipeline( model: str, base_url: str = "http://localhost:8000/v1", prompt_column: Optional[str] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, max_new_tokens: int = 8192, num_generations: int = 1, ) -> Pipeline: generation_kwargs = {"max_new_tokens": max_new_tokens} if temperature is not None: generation_kwargs["temperature"] = temperature if top_p is not None: generation_kwargs["top_p"] = top_p with Pipeline().ray() as pipeline: TextGeneration( llm=OpenAILLM( base_url=base_url, api_key="something", model=model, # thinking can take some time... timeout=10 * 60, generation_kwargs=generation_kwargs, ), input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion num_generations=num_generations, ) return pipeline if __name__ == "__main__": import argparse from datasets import load_dataset parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") parser.add_argument( "--hf-dataset", type=str, required=True, help="HuggingFace dataset to load", ) parser.add_argument( "--hf-dataset-config", type=str, required=False, help="Dataset config to use", ) parser.add_argument( "--hf-dataset-split", type=str, default="train", help="Dataset split to use", ) parser.add_argument("--prompt-column", type=str, default="prompt") parser.add_argument( "--model", type=str, required=True, help="Model name to use for generation", ) parser.add_argument( "--vllm-server-url", type=str, default="http://localhost:8000/v1", help="URL of the vLLM server", ) parser.add_argument( "--temperature", type=float, help="Temperature for generation", ) parser.add_argument( "--top-p", type=float, help="Top-p value for generation", ) 
parser.add_argument( "--max-new-tokens", type=int, default=8192, help="Maximum number of new tokens to generate", ) parser.add_argument( "--num-generations", type=int, default=1, help="Number of generations per problem", ) parser.add_argument( "--hf-output-dataset", type=str, required=False, help="HuggingFace repo to push results to", ) parser.add_argument( "--private", action="store_true", help="Whether to make the output dataset private when pushing to HF Hub", ) args = parser.parse_args() print("\nRunning with arguments:") for arg, value in vars(args).items(): print(f" {arg}: {value}") print() print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split) print("Dataset loaded!") pipeline = build_distilabel_pipeline( model=args.model, base_url=args.vllm_server_url, prompt_column=args.prompt_column, temperature=args.temperature, top_p=args.top_p, max_new_tokens=args.max_new_tokens, num_generations=args.num_generations, ) print("Running generation pipeline...") distiset = pipeline.run(dataset=dataset, use_cache=False) print("Generation pipeline finished!") if args.hf_output_dataset: print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") distiset.push_to_hub(args.hf_output_dataset, private=args.private) print("Dataset pushed!") ================================================ FILE: src/open-r1-multimodal/src/open_r1/grpo_jsonl.py ================================================ # Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import re import pathlib from datetime import datetime from dataclasses import dataclass, field from typing import Optional from babel.numbers import parse_decimal from utils.math import compute_score from datasets import load_dataset, load_from_disk from transformers import Qwen2VLForConditionalGeneration from math_verify import parse, verify from trainer import VLMGRPOTrainer, GRPOConfig # from trainer import VLMGRPOTrainer, GRPOConfig from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config import PIL from Levenshtein import ratio from utils.pycocotools.coco import COCO from utils.pycocotools.cocoeval import COCOeval import json import math from json_repair import repair_json from vlm_modules import * from typing import Tuple from transformers.utils import logging from transformers import AutoProcessor, AutoTokenizer from openai import OpenAI logger = logging.get_logger(__name__) client = OpenAI( api_key=os.getenv("OPENAI_API_KEY", ""), # Must be set via environment variable base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") ) from qwen2_5vl_monkey_patch import monkey_patch_qwen2_5vl_flash_attn, monkey_patch_qwen2_5vl_forward, monkey_patch_torch_load monkey_patch_qwen2_5vl_flash_attn() monkey_patch_torch_load() tokenizer = None def initialize_tokenizer(model_path): global tokenizer if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(model_path,local_files_only=True) print(f"Is Fast Tokenizer? 
{tokenizer.is_fast}") return tokenizer @dataclass class GRPOScriptArguments(ScriptArguments): """ Script arguments for the GRPO training script. """ data_file_paths: str = field( default=None, metadata={"help": "Paths to data files, separated by ':'"}, ) image_folders: str = field( default=None, metadata={"help": "Paths to image folders, separated by ':'"}, ) arrow_cache_dir: str = field( default=None, metadata={"help": "Path to arrow cache directory"}, ) val_split_ratio: float = field( default=0.0, metadata={"help": "Ratio of validation split, default 0.0"}, ) reward_funcs: list[str] = field( default_factory=lambda: ["accuracy", "format"], metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, ) max_pixels: Optional[int] = field( default=12845056, metadata={"help": "Maximum number of pixels for the image (for QwenVL)"}, ) min_pixels: Optional[int] = field( default=3136, metadata={"help": "Minimum number of pixels for the image (for QwenVL)"}, ) max_anyres_num: Optional[int] = field( default=12, metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"}, ) reward_method: Optional[str] = field( default=None, metadata={ "help": "Choose reward method: 'default', 'mcp', ..." }, ) task_type: Optional[str] = field( default=None, metadata={"help": "Choose task type: 'default', 'gui', ..."}, ) is_reward_customized_from_vlm_module: bool = field( default=False, metadata={"help": "Whether to use a customized reward from vlm module"}, ) def extract_choice(text): # 1. Clean and normalize text text = text.upper() # Convert to uppercase text = re.sub(r'\s+', ' ', text) # Normalize spaces # 2. Choice should not have uppercase letters before or after choices = re.findall(r'(? 
len(text) * 0.7: # In last 30% of text choice_scores[choice] += 2 # Add points if followed by punctuation if pos < len(text) - 1 and text[pos+1] in '。.!!,,': choice_scores[choice] += 1 # Return highest scoring choice return max(choice_scores.items(), key=lambda x: x[1])[0] def evaluate_answer_similarity(student_answer, ground_truth): """Use llm to evaluate answer similarity.""" try: response = client.chat.completions.create( model="qwen2.5:7b", messages=[ { "role": "user", "content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed." }, { "role": "user", "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:" } ], temperature=0 ) result = response.choices[0].message.content.strip() return float(result) except Exception as e: print(f"Error in GPT evaluation: {e}") # If API call fails, fall back to simple text matching return 1.0 if student_answer ==ground_truth else 0.0 def llm_reward(content, sol, **kwargs): # Extract answer from content if it has think/answer tags sol_match = re.search(r'(.*?)', sol) ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() # Extract answer from content if it has think/answer tags content_matches = re.findall(r'(.*?)', content, re.DOTALL) student_answer = content_matches[-1].strip() if content_matches else content.strip() return evaluate_answer_similarity(student_answer, ground_truth) def mcq_reward(content, sol, **kwargs): # For multiple choice, extract and compare choices sol_match = re.search(r'(.*?)', sol) ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() has_choices = extract_choice(ground_truth) correct_choice = 
has_choices.upper() if has_choices else sol.strip() # Extract answer from content if it has think/answer tags content_match = re.search(r'(.*?)', content, re.DOTALL) student_answer = content_match.group(1).strip() if content_match else content.strip() student_choice = extract_choice(student_answer) if student_choice: reward = 1.0 if student_choice == correct_choice else 0.0 else: reward = 0.0 return reward def yes_no_reward(content, sol, **kwargs): content = content.lower() sol = sol.lower() # Extract answer from solution if it has think/answer tags sol_match = re.search(r'(.*?)', sol) ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() # Extract answer from content if it has think/answer tags content_match = re.search(r'(.*?)', content, re.DOTALL) student_answer = content_match.group(1).strip() if content_match else content.strip() ground_yes_no = re.search(r'(yes|no)', ground_truth) ground_yes_no = ground_yes_no.group(1) if ground_yes_no else '' student_yes_no = re.search(r'(yes|no)', student_answer) student_yes_no = student_yes_no.group(1) if student_yes_no else '' reward = 1.0 if ground_yes_no == student_yes_no else 0.0 return reward # score_type: 0 for mAP, 1 for mAP 50 def calculate_map(pred_bbox_list, gt_bbox_list, score_type=0): # Calculate mAP # Initialize COCO object for ground truth gt_json = {"annotations": [], "images": [], "categories": []} gt_json["images"] = [{ "id": 0, "width": 2048, "height": 2048, "file_name": "image_0.jpg" }] gt_json["categories"] = [] cats2id = {} cat_count = 0 for idx, gt_bbox in enumerate(gt_bbox_list): if gt_bbox["label"] not in cats2id: cats2id[gt_bbox["label"]] = cat_count gt_json["categories"].append({ "id": cat_count, "name": gt_bbox["label"] }) cat_count += 1 gt_json["annotations"].append({ "id": idx+1, "image_id": 0, "category_id": cats2id[gt_bbox["label"]], "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - 
gt_bbox["bbox_2d"][1]], "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]), "iscrowd": 0 }) coco_gt = COCO(gt_json) dt_json = [] for idx, pred_bbox in enumerate(pred_bbox_list): try: dt_json.append({ "image_id": 0, "category_id": cats2id[pred_bbox["label"]], "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]], "score": 1.0, "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]) }) except: pass if len(dt_json) == 0: return 0.0 coco_dt = coco_gt.loadRes(dt_json) coco_eval = COCOeval(coco_gt, coco_dt, "bbox") coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() return coco_eval.stats[score_type] def map_reward(content, sol, length_reward=False, score_type=0, **kwargs): """ Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes. Args: content (str): String containing predicted bounding boxes in JSON format sol (str): String containing ground truth bounding boxes in JSON format length_reward (bool, optional): Whether to include length penalty in reward calculation. Defaults to False. score_type (int, optional): Type of COCO evaluation metric to use. Defaults to 0 (mAP). **kwargs: Additional keyword arguments Returns: float: mAP reward score between 0 and 1. If length_reward is True, the score is multiplied by a length penalty factor. 
""" # Extract JSON content between ```json tags pattern = r'```json(.*?)```' json_match = re.findall(pattern, sol, re.DOTALL) bbox_json = json_match[-1].strip() if json_match else None # Parse ground truth JSON to get bbox list gt_bbox_list = [] if bbox_json: bbox_data = json.loads(bbox_json) gt_bbox_list = [item for item in bbox_data] # Parse predicted JSON to get bbox list pred_bbox_list = [] json_match = re.findall(pattern, content, re.DOTALL) if json_match: try: bbox_data = json.loads(json_match[-1].strip()) pred_bbox_list = [item for item in bbox_data] except: # Return empty list if JSON parsing fails pred_bbox_list = [] # Calculate mAP if both prediction and ground truth exist if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0: bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list, score_type=score_type) elif len(pred_bbox_list) == 0 and len(gt_bbox_list) == 0: bbox_reward = 1.0 else: bbox_reward = 0.0 if length_reward: # Calculate length penalty based on ratio of ground truth to predicted bounding boxes gt_length = len(gt_bbox_list) pred_length = len(pred_bbox_list) # Full score if prediction has fewer boxes than ground truth, otherwise penalize proportionally length_score = 1.0 if gt_length >= pred_length else gt_length/pred_length return bbox_reward * length_score else: return bbox_reward def od_reward(content, sol, score_type=0, **kwargs): """ Calculate reward for object detection task by comparing predicted and ground truth answers. 
Args: content (str): Model's predicted answer containing bounding box annotations sol (str): Ground truth answer containing bounding box annotations score_type (int): Type of COCO evaluation metric to use (default: 0 for mAP) **kwargs: Additional keyword arguments Returns: float: Reward score between 0 and 1 based on mAP between predicted and ground truth boxes """ # Pattern to extract content between tags match_pattern = r'(.*?)' # Extract ground truth answer sol_match = re.search(match_pattern, sol, re.DOTALL) ground_truth = sol_match.group(1).strip() if sol_match else None # Extract predicted answer (using last match if multiple) content_match = re.findall(match_pattern, content, re.DOTALL) student_answer = content_match[-1].strip() if content_match else None # Return 0 if no prediction if student_answer is None: return 0.0 # Return 1 if both prediction and ground truth are None elif ground_truth == "None" and student_answer == "None": return 1.0 # Otherwise calculate mAP between prediction and ground truth else: return map_reward(student_answer, ground_truth, score_type=score_type) def odLength_reward(content, sol, **kwargs): """ Calculate reward for object detection task with length penalty. 
Args: content (str): Model's predicted answer containing bounding box annotations sol (str): Ground truth answer containing bounding box annotations **kwargs: Additional keyword arguments Returns: float: Reward score between 0 and 1 based on mAP and length penalty """ # Pattern to extract content between tags match_pattern = r'(.*?)' # Extract ground truth answer sol_match = re.search(match_pattern, sol, re.DOTALL) ground_truth = sol_match.group(1).strip() if sol_match else None # Extract predicted answer (using last match if multiple) content_match = re.findall(match_pattern, content, re.DOTALL) student_answer = content_match[-1].strip() if content_match else None # Return 0 if no prediction if student_answer is None: return 0.0 # Return 1 if both prediction and ground truth are None elif ground_truth == "None" and student_answer == "None": return 1.0 # Calculate mAP with length penalty else: bbox_reward = map_reward(student_answer, ground_truth, length_reward=True, score_type=0) return bbox_reward def iou(box1, box2): inter_x1 = max(box1[0], box2[0]) inter_y1 = max(box1[1], box2[1]) inter_x2 = min(box1[2]-1, box2[2]-1) inter_y2 = min(box1[3]-1, box2[3]-1) if inter_x1 < inter_x2 and inter_y1 < inter_y2: inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) else: inter = 0 union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter return float(inter)/union def detection_score(content, sol, iou_threshold=0.5, alpha=0.7, beta=0.0, gamma=0.3): pattern = r'```json(.*?)```' json_match = re.search(pattern, clean_text(content), re.DOTALL) content_bbox_json = json_match.group(1).strip() if json_match else None if content_bbox_json: try: bbox_data = json.loads(content_bbox_json) pred_boxes = [item for item in bbox_data] except: pred_boxes = [] else: pred_boxes = [] pattern = r'```json(.*?)```' json_match = re.search(pattern, clean_text(sol), re.DOTALL) sol_bbox_json = json_match.group(1).strip() if json_match else None if sol_bbox_json: 
bbox_data = json.loads(sol_bbox_json) gt_boxes = [item for item in bbox_data] else: gt_boxes = [] """ Calculate the comprehensive score for object detection Parameters: pred_boxes: List of predicted boxes, each element is in the format {"bbox_2d": [x1, y1, x2, y2], "label": "category name"} gt_boxes: List of ground truth boxes, each element is in the format {"bbox_2d": [x1, y1, x2, y2], "label": "category name"} iou_threshold: IoU threshold, default is 0.5 alpha: Position accuracy weight, default is 0.7 beta: Label accuracy weight, default is 0.0 gamma: Completeness weight (penalty for missed/false detections), default is 0.3 Returns: Comprehensive score, ranging from [0.0, 1.0] """ # Handle edge cases if len(gt_boxes) == 0: return 1.0 if not pred_boxes else 0.0 if len(pred_boxes) == 0: return 0.0 # Initialize matching results matches = [] # Store matched pairs of predicted and ground truth boxes unmatched_preds = list(range(len(pred_boxes))) # Indices of unmatched predicted boxes unmatched_gts = list(range(len(gt_boxes))) # Indices of unmatched ground truth boxes # Calculate IoU matrix between all predicted and ground truth boxes iou_matrix = [] for pred_idx, pred_box in enumerate(pred_boxes): iou_row = [] for gt_idx, gt_box in enumerate(gt_boxes): try: curr_iou = iou(pred_box["bbox_2d"], gt_box["bbox_2d"]) except: curr_iou = 0.0 iou_row.append(curr_iou) iou_matrix.append(iou_row) # Greedy matching: find the best match for each predicted box while unmatched_preds and unmatched_gts: # Find the maximum IoU max_iou = -1 max_pred_idx = -1 max_gt_idx = -1 for pred_idx in unmatched_preds: for gt_idx in unmatched_gts: curr_iou = iou_matrix[pred_idx][gt_idx] if curr_iou > max_iou: max_iou = curr_iou max_pred_idx = pred_idx max_gt_idx = gt_idx # Stop matching if the maximum IoU is below the threshold if max_iou < iou_threshold: break # Record matching results try: pred_label = pred_boxes[max_pred_idx]["label"].lower() except: pred_box = "" try: gt_label = 
gt_boxes[max_gt_idx]["label"].lower() except: gt_label = "" label_correct = (pred_label == gt_label) if label_correct: matches.append({ "pred_idx": max_pred_idx, "gt_idx": max_gt_idx, "iou": max_iou, "label_correct": label_correct }) else: matches.append({ "pred_idx": max_pred_idx, "gt_idx": max_gt_idx, "iou": 0, "label_correct": label_correct }) # Remove matched boxes from the unmatched list unmatched_preds.remove(max_pred_idx) unmatched_gts.remove(max_gt_idx) # Calculate position accuracy score (average IoU) position_score = sum(m["iou"] for m in matches) / len(gt_boxes) if matches else 0.0 # Calculate label accuracy score label_score = sum(1.0 for m in matches if m["label_correct"]) / len(gt_boxes) if matches else 0.0 # Calculate completeness score (considering missed and false detections) # Miss rate = number of unmatched ground truth boxes / total number of ground truth boxes # False alarm rate = number of unmatched predicted boxes / total number of predicted boxes miss_rate = len(unmatched_gts) / len(gt_boxes) false_alarm_rate = len(unmatched_preds) / len(pred_boxes) if pred_boxes else 0.0 # Completeness score = 1 - (miss rate + false alarm rate) / 2 completeness_score = 1.0 - (miss_rate + false_alarm_rate) / 2.0 # Calculate the final comprehensive score final_score = ( alpha * position_score + beta * label_score + gamma * completeness_score ) / (alpha + beta + gamma) return final_score def cosine_reward(content, tokenizer, acc_reward, **kwargs): #https://arxiv.org/abs/2502.03373 min_len_value_wrong = 0.0 max_len_value_wrong = -0.5 min_len_value_correct = 1.0 max_len_value_correct = 0.5 cosine_max_len = 1024 # processing_class = AutoProcessor.from_pretrained(model_path) # tokenizer = processing_class.tokenizer gen_len = len(tokenizer.encode(content)) acc_reward = 1.0 is_correct = acc_reward >= 0.7 if is_correct: # Swap min/max for correct answers min_value = max_len_value_correct max_value = min_len_value_correct else: min_value = min_len_value_wrong 
max_value = max_len_value_wrong reward = max_value - (max_value - min_value) * (1 - math.cos(gen_len * math.pi / cosine_max_len)) / 2 return reward def repetition_reward(content, **kwargs): max_penalty = -1.0 if content == '': return 0.0 # First, try to extract explicitly marked JSON sections pattern = r'```json(.*?)```' json_match = re.search(pattern, content, re.DOTALL) if json_match: bbox_json = json_match.group(1).strip() else: # If no explicitly marked JSON is found, try to find any possible JSON sections pattern = r'```(.*?)```' json_match = re.search(pattern, content, re.DOTALL) bbox_json = json_match.group(1).strip() if json_match else None # If still not found, try to find possible JSON array sections if not bbox_json: pattern = r'\[\s*{.*?"bbox_2d".*?"label".*?}\s*\]' json_match = re.search(pattern, content, re.DOTALL) bbox_json = json_match.group(0) if json_match else None # Try to parse JSON data if bbox_json: try: # Try direct parsing data = json.loads(bbox_json) except json.JSONDecodeError: try: # If direct parsing fails, try using json_repair to repair repaired_json = repair_json(bbox_json) data = json.loads(repaired_json) except: # If repair also fails, switch to plain text processing data = None if data and isinstance(data, list): # Ensure data is in list format try: # For JSON data, set ngram_size to 1 ngram_size = 1 # Combine 'bbox_2d' and 'label' of each object into a string items = [] for item in data: if 'bbox_2d' in item and 'label' in item: items.append(f"{item['bbox_2d']}_{item['label']}") @staticmethod def zipngram(text: list, ngram_size: int): return zip(*[text[i:] for i in range(ngram_size)]) ngrams = set() total = 0 for ng in zipngram(items, ngram_size): ngrams.add(ng) total += 1 if total == 0: return 0.0 scaling = 1 - len(ngrams) / total reward = scaling * max_penalty return reward except KeyError: # If necessary keys are missing, switch to plain text processing pass # If no JSON section is found or JSON processing fails, treat as 
plain text ngram_size = 6 if len(content.split()) < ngram_size: return 0.0 @staticmethod def zipngram(text: str, ngram_size: int): words = text.lower().split() return zip(*[words[i:] for i in range(ngram_size)]) ngrams = set() total = 0 for ng in zipngram(content, ngram_size): ngrams.add(ng) total += 1 scaling = 1 - len(ngrams) / total reward = scaling * max_penalty return reward def repetition_rewards(completions, solution, **kwargs): contents = [completion[0]["content"] for completion in completions] rewards = [] for content, sol in zip(contents, solution): reward = repetition_reward(content) rewards.append(reward) if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") current_time = datetime.now().strftime("%d-%H-%M-%S-%f") image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None problem = kwargs.get("problem")[0] if reward <= 0.0: # this condition can be changed for debug with open(log_path+"_repetition.txt", "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n") f.write(f"image_path: {image_path}\n") f.write(f"problem: {problem}\n") f.write(f"Content: {content}\n") f.write(f"Solution: {sol}\n") return rewards def cosine_rewards(completions, solution, **kwargs): contents = [completion[0]["content"] for completion in completions] rewards = [] for content, sol in zip(contents, solution): clean_content = clean_text(content) sol = clean_text(sol) if sol == "none": if clean_content == "none": acc_reward = 1.0 else: acc_reward = 0.0 else: acc_reward = detection_score(clean_content, sol) reward = cosine_reward(content, tokenizer, acc_reward) rewards.append(reward) if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") current_time = datetime.now().strftime("%d-%H-%M-%S-%f") image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None problem = kwargs.get("problem")[0] if reward <=1.0: # this condition can be changed for debug with 
open(log_path+"_cosine.txt", "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n") f.write(f"image_path: {image_path}\n") f.write(f"problem: {problem}\n") f.write(f"Content: {content}\n") f.write(f"Solution: {sol}\n") return rewards def numeric_reward(content, sol, **kwargs): content = clean_text(content) sol = clean_text(sol) try: content, sol = float(content), float(sol) return 1.0 if content == sol else 0.0 except: return None def math_reward(content, sol, **kwargs): content = clean_text(content) sol = clean_text(sol) return compute_score(content, sol) def clean_text(text, exclue_chars=['\n', '\r']): # Extract content between and if present answer_matches = re.findall(r'(.*?)', text, re.DOTALL) if answer_matches: # Use the last match text = answer_matches[-1] for char in exclue_chars: if char in ['\n', '\r']: # If there is a space before the newline, remove the newline text = re.sub(r'(?<=\s)' + re.escape(char), '', text) # If there is no space before the newline, replace it with a space text = re.sub(r'(?(.*?)', sol) ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() # Extract answer from content if it has think/answer tags content_matches = re.findall(r'(.*?)', content, re.DOTALL) student_answer = content_matches[-1].strip() if content_matches else content.strip() # Try symbolic verification first for numeric answers try: answer = parse(student_answer) if float(verify(answer, parse(ground_truth))) > 0: reward = 1.0 except Exception: pass # Continue to next verification method if this fails # If symbolic verification failed, try string matching or fuzzy matching if reward == 0.0: try: # Check if ground truth contains numbers has_numbers = bool(re.search(r'\d', ground_truth)) # Check if it's a multiple choice question has_choices = extract_choice(ground_truth) if has_numbers: # For numeric answers, use exact matching reward = numeric_reward(student_answer, ground_truth) if reward is 
None: reward = ratio(clean_text(student_answer), clean_text(ground_truth)) elif has_choices: # For multiple choice, extract and compare choices correct_choice = has_choices.upper() student_choice = extract_choice(student_answer) if student_choice: reward = 1.0 if student_choice == correct_choice else 0.0 else: # For text answers, use fuzzy matching reward = ratio(clean_text(student_answer), clean_text(ground_truth)) except Exception: pass # Keep reward as 0.0 if all methods fail return reward def accuracy_reward(completions, solution, **kwargs): """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching.""" contents = [completion[0]["content"] for completion in completions] rewards = [] for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")): # if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function if accu_reward_method == "mcq": reward = mcq_reward(content, sol) elif accu_reward_method == 'yes_no': reward = yes_no_reward(content, sol) elif accu_reward_method == 'llm': reward = llm_reward(content, sol) elif accu_reward_method == 'map': reward = map_reward(content, sol) elif accu_reward_method == 'math': reward = math_reward(content, sol) elif accu_reward_method == 'weighted_sum': clean_content = clean_text(content) sol = clean_text(sol) if sol == "none": if clean_content == "none": reward = 1.0 else: reward = 0.0 else: reward = detection_score(clean_content, sol) elif accu_reward_method == 'od_ap': reward = od_reward(content, sol) elif accu_reward_method == 'od_ap50': reward = od_reward(content, sol, score_type=1) elif accu_reward_method == 'odLength': reward = odLength_reward(content, sol) elif accu_reward_method == 'all_match': reward = all_match_reward(content, sol) else: reward = default_accuracy_reward(content, sol) rewards.append(reward) if os.getenv("DEBUG_MODE") == "true": log_path = 
os.getenv("LOG_PATH") current_time = datetime.now().strftime("%d-%H-%M-%S-%f") image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None problem = kwargs.get("problem")[0] if reward <= 1.0: # this condition can be changed for debug with open(log_path, "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n") f.write(f"accu_reward_method: {accu_reward_method}\n") f.write(f"image_path: {image_path}\n") f.write(f"problem: {problem}\n") f.write(f"Content: {content}\n") f.write(f"Solution: {sol}\n") return rewards def format_reward(completions, **kwargs): """Reward function that checks if the completion has a specific format.""" pattern = r".*?\s*.*?" completion_contents = [completion[0]["content"] for completion in completions] matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents] current_time = datetime.now().strftime("%d-%H-%M-%S-%f") if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Format reward -------------\n") for content, match in zip(completion_contents, matches): f.write(f"Content: {content}\n") f.write(f"Has format: {bool(match)}\n") return [1.0 if match else 0.0 for match in matches] reward_funcs_registry = { "accuracy": accuracy_reward, "format": format_reward, "length": cosine_rewards, "repetition": repetition_rewards, } @dataclass class GRPOModelConfig(ModelConfig): freeze_vision_modules: bool = False SYSTEM_PROMPT = ( "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " "first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning " "process and answer are enclosed within and tags, respectively, i.e., " " reasoning process here answer here " ) def get_vlm_module(model_name_or_path): if "qwen" in model_name_or_path.lower(): return Qwen2VLModule else: raise ValueError(f"Unsupported model: {model_name_or_path}") def main(script_args, training_args, model_args): # Load the VLM module vlm_module_cls = get_vlm_module(model_args.model_name_or_path) print("using vlm module:", vlm_module_cls.__name__) question_prompt = vlm_module_cls.get_question_template(task_type=script_args.task_type) # Get reward functions if script_args.is_reward_customized_from_vlm_module: reward_funcs = [vlm_module_cls.select_reward_func(func, script_args.task_type) for func in script_args.reward_funcs] else: reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] print("reward_funcs:", reward_funcs) # Load the JSONL datasets import json from datasets import Dataset data_files = script_args.data_file_paths.split(":") image_folders = script_args.image_folders.split(":") if len(data_files) != len(image_folders): raise ValueError("Number of data files must match number of image folders") if script_args.reward_method is None: accu_reward_methods = ["default"] * len(data_files) else: accu_reward_methods = script_args.reward_method.split(":") assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}" if len(data_files) != len(image_folders): raise ValueError("Number of data files must match number of image folders") all_data = [] for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods): with open(data_file, 'r') as f: for line in f: item = json.loads(line) if 'image' in item: if isinstance(item['image'], str): # Store image path instead of loading the image item['image_path'] = [os.path.join(image_folder, item['image'])] del item['image'] # remove 
the image column so that it can be loaded later elif isinstance(item['image'], list): # if the image is a list, then it is a list of images (for multi-image input) item['image_path'] = [os.path.join(image_folder, image) for image in item['image']] del item['image'] # remove the image column so that it can be loaded later else: raise ValueError(f"Unsupported image type: {type(item['image'])}") # Remove immediate image loading item['problem'] = item['conversations'][0]['value'].replace('', '') # Handle solution that could be a float or string solution_value = item['conversations'][1]['value'] if isinstance(solution_value, str): item['solution'] = solution_value.replace('', '').replace('', '').strip() else: # If it's a float or other non-string type, keep it as is item['solution'] = str(solution_value) del item['conversations'] item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value all_data.append(item) dataset = Dataset.from_list(all_data) def make_conversation_from_jsonl(example): if 'image_path' in example and example['image_path'] is not None: assert all(os.path.exists(p) for p in example['image_path']), f"Image paths do not exist: {example['image_path']}" # Don't load image here, just store the path return { 'image_path': [p for p in example['image_path']], # Store path instead of loaded image 'problem': example['problem'], 'solution': f" {example['solution']} ", 'accu_reward_method': example['accu_reward_method'], 'prompt': [{ 'role': 'user', 'content': [ *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))), {'type': 'text', 'text': question_prompt.format(Question=example['problem'])} ] }] } else: return { 'problem': example['problem'], 'solution': f" {example['solution']} ", 'accu_reward_method': example['accu_reward_method'], 'prompt': [{ 'role': 'user', 'content': [ {'type': 'text', 'text': 
question_prompt.format(Question=example['problem'])} ] }] } # Map the conversations dataset = dataset.map(make_conversation_from_jsonl, num_proc=8) # print(dataset[0]) # Split dataset for validation if requested splits = {'train': dataset} if script_args.val_split_ratio > 0: train_val_split = dataset.train_test_split( test_size=script_args.val_split_ratio ) splits['train'] = train_val_split['train'] splits['validation'] = train_val_split['test'] # Select trainer class based on vlm_trainer argument trainer_cls = VLMGRPOTrainer print("using trainer:", trainer_cls.__name__) initialize_tokenizer(model_args.model_name_or_path) # Initialize the GRPO trainer trainer = trainer_cls( model=model_args.model_name_or_path, reward_funcs=reward_funcs, args=training_args, vlm_module=vlm_module_cls(), train_dataset=splits['train'], eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None, peft_config=get_peft_config(model_args), freeze_vision_modules=model_args.freeze_vision_modules, attn_implementation=model_args.attn_implementation, max_pixels=script_args.max_pixels, min_pixels=script_args.min_pixels, max_anyres_num=script_args.max_anyres_num, ) # Train and push the model to the Hub if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub() if __name__ == "__main__": parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() if training_args.deepspeed and "zero3" in training_args.deepspeed: print("zero3 is used, qwen2_5vl forward monkey patch is applied") monkey_patch_qwen2_5vl_forward() main(script_args, training_args, model_args) ================================================ FILE: src/open-r1-multimodal/src/open_r1/qwen2_5vl_monkey_patch.py 
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple, Optional


def qwen2_5vl_vision_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    """Patched forward for the Qwen2.5-VL vision flash-attention layer.

    Mirrors the upstream implementation, except that the rotary cos/sin tables
    are explicitly upcast to float32 before `apply_rotary_pos_emb_flashatt`
    (the flash-attention dtype fix referenced in the section banner above).
    """
    seq_length = hidden_states.shape[0]
    # Fused qkv projection -> (3, seq, num_heads, head_dim), then split into q, k, v.
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        # NOTE(review): `logger` is imported near the bottom of this module (from
        # deepspeed.utils); it resolves at call time, so this works at runtime.
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # The actual fix: force float32 rotary tables regardless of input dtype.
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)
    # Longest sequence in the packed (varlen) batch, from cumulative lengths.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


def monkey_patch_qwen2_5vl_flash_attn():
    # Swap in the patched vision flash-attention forward defined above.
    Qwen2_5_VLVisionFlashAttention2.forward = qwen2_5vl_vision_flash_attn_forward


# ----------------------- Fix the process pending bug when using data mixture of image-text data and pure-text under deepseed zero3-----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
from typing import List, Union
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration


def qwen2_5vl_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    """Patched `Qwen2_5_VLForConditionalGeneration.forward` for ZeRO-3 training.

    Under DeepSpeed ZeRO-3, if some ranks have image inputs and others are
    pure-text, the ranks that skip the vision tower never participate in its
    parameter gathering and the job hangs. This forward first agrees across
    ranks (via all_reduce MAX) whether any rank has images; ranks without
    images then run the vision tower once on a tiny dummy input (no_grad) so
    every rank joins the collective. Otherwise it follows the upstream
    forward: splice image/video embeddings into the token embeddings, compute
    multimodal RoPE position ids, run the language model, and optionally
    compute the shifted cross-entropy loss.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # 1/0 flag per rank: does this rank carry image pixels?
    has_images_global = False
    if pixel_values is not None:
        has_images_local = torch.tensor(1, device=input_ids.device)
    else:
        has_images_local = torch.tensor(0, device=input_ids.device)
    # Use all_reduce to ensure all GPUs know if there are images to process.
    torch.distributed.all_reduce(has_images_local, op=torch.distributed.ReduceOp.MAX)
    has_images_global = has_images_local.item() > 0

    # If there are image inputs globally, ensure all GPUs call the visual model.
    if has_images_global:
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Sanity check: one vision feature per <image> placeholder token.
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            # Scatter the vision features into the embedding positions of the
            # image placeholder tokens.
            mask = input_ids == self.config.image_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            image_mask = mask_expanded.to(inputs_embeds.device)
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        else:
            with torch.no_grad():
                # Create dummy image data purely to trigger ZeRO-3 parameter
                # synchronization of the vision tower on image-less ranks; the
                # result is discarded.
                dummy_pixel_values = torch.zeros((4, 1176), device=input_ids.device, dtype=self.visual.dtype)
                dummy_grid_thw = torch.tensor([[1, 2, 2]], device=input_ids.device)
                _ = self.visual(dummy_pixel_values, grid_thw=dummy_grid_thw)

    # Currently, video processing is not handled by the hang workaround above
    # (no all_reduce/dummy call for the video path).
    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        mask = input_ids == self.config.video_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        video_mask = mask_expanded.to(inputs_embeds.device)
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            # Cache the deltas so decode steps can derive positions cheaply.
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            # Expand to the 3 multimodal RoPE axes (temporal, height, width).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )


def monkey_patch_qwen2_5vl_forward():
    # Replace the upstream forward with the ZeRO-3-safe version above.
    Qwen2_5_VLForConditionalGeneration.forward = qwen2_5vl_forward


# ----------------------- Set the Weights only as False in torch.load (In Pytorch 2.6, this is default as True)-----------------------
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils import logger, log_dist


def weigths_only_load(self, path: str, map_location=None):
    """Load a DeepSpeed checkpoint partition with `weights_only=False`.

    PyTorch 2.6 flipped the `torch.load` default to `weights_only=True`, which
    breaks loading DeepSpeed checkpoints that contain non-tensor objects.
    NOTE(review): the name keeps the original "weigths" spelling since it is
    assigned by `monkey_patch_torch_load` below.
    """
    logger.info(f"[Torch] Loading checkpoint from {path}...")
    partition = torch.load(path, map_location=map_location, weights_only=False)
    logger.info(f"[Torch] Loaded checkpoint from {path}.")
    return partition


def monkey_patch_torch_load():
    # Route DeepSpeed's checkpoint loading through the permissive loader above.
    TorchCheckpointEngine.load = weigths_only_load
# ===== src/open-r1-multimodal/src/open_r1/trainer/__init__.py =====
from .grpo_trainer import VLMGRPOTrainer
from .grpo_config import GRPOConfig

# NOTE(review): GRPOConfig is imported here but not listed in __all__ — confirm
# whether it should be re-exported.
__all__ = ["VLMGRPOTrainer"]


# ===== src/open-r1-multimodal/src/open_r1/trainer/grpo_config.py =====
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class GRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for
            generation, improving generation speed. However, disabling this option allows training models that
            exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option
            is not compatible with vLLM generation.

        > Parameters that control generation

        temperature (`float`, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        top_p (`float`, *optional*, defaults to `1.0`):
            Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
            `1.0` to consider all tokens.
        top_k (`int` or `None`, *optional*, defaults to `50`):
            Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering
            is disabled.
        min_p (`float` or `None`, *optional*, defaults to `None`):
            Minimum token probability, which will be scaled by the probability of the most likely token. It must be
            a value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
        repetition_penalty (`float`, *optional*, defaults to `1.0`):
            Float that penalizes new tokens based on whether they appear in the prompt and the generated text so
            far. Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to
            repeat tokens.
        cache_implementation (`str` or `None`, *optional*, defaults to `None`):
            Implementation of the cache method for faster generation when use_vllm is set to False.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will
            be shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM)
            errors during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the
            hardware support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving
            training speed, but may be numerically unstable for long training runs.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        epsilon_high (`float` or `None`, *optional*, defaults to `None`):
            Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the
            lower-bound specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476)
            recommends `0.28`.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps,
            using the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `GRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
            "must be divisible by this value."
        },
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation
    temperature: float = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    top_p: float = field(
        default=1.0,
        metadata={
            "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
            "Set to 1.0 to consider all tokens."
        },
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={
            "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
            "top-k-filtering is disabled."
        },
    )
    min_p: Optional[float] = field(
        default=None,
        metadata={
            "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
            "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
        },
    )
    repetition_penalty: float = field(
        default=1.0,
        metadata={
            "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
            "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
            "to repeat tokens."
        },
    )
    cache_implementation: Optional[str] = field(
        default=None,
        metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed, but may be numerically unstable for long training runs."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
    epsilon_high: Optional[float] = field(
        default=None,
        metadata={
            "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
        },
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )


# ===== src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py =====
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union, Sized

from qwen_vl_utils import process_vision_info
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
# from trl import GRPOTrainer
from accelerate.utils import is_peft_model, set_seed
import PIL.Image
import copy
from torch.utils.data import Sampler
import warnings

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_wandb_available():
    import wandb

from vlm_modules.vlm_module import VLMBaseModule

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        mini_repeat_count (`int`):
            Number of times to repeat each index per batch.
        batch_size (`int`, *optional*, defaults to `1`):
            Number of unique indices per batch.
        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times to repeat the full sampling process.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Random seed for reproducibility.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        # Dedicated generator so shuffling is reproducible and independent of the global RNG state.
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        # Shuffle once, then split into full batches of `batch_size` unique indices.
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
        # Drop the trailing partial chunk so every yielded batch has exactly `batch_size` unique indices.
        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
        # Yield order: each chunk is replayed `repeat_count` times, and within a chunk each
        # index is emitted `mini_repeat_count` times consecutively (one group per prompt in GRPO).
        for chunk in indexes:
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

    def __len__(self) -> int:
        # Total emitted indices; NOTE(review): partial chunks dropped in __iter__ are still
        # counted here, so len() can slightly overstate the true iteration length.
        return self.num_samples * self.mini_repeat_count * self.repeat_count


class VLMGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
    Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs="weqweasdas/RM-Gemma-2B",
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                path to a *directory* containing model weights saved using
                [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
                using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
                keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details, see
                  [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing different
            types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
            ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using
            [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
            functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
            are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        vlm_module: VLMBaseModule = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        freeze_vision_modules: Optional[bool] = False,
        attn_implementation: str = "flash_attention_2",
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        # Args: build a default GRPOConfig named after the model when none is provided.
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")
        self.vlm_module = vlm_module

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        # FIXME
        # Remember to modify it in the InternVL model as well
        model_init_kwargs["attn_implementation"] = attn_implementation
        if model_init_kwargs.get("torch_dtype") is None:
            model_init_kwargs["torch_dtype"] = torch_dtype
        assert isinstance(model, str), "model must be a string in the current implementation"
        model_id = model
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        model_init_kwargs["use_cache"] = (
            False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
        )
        # The VLM module decides the concrete model class (e.g. Qwen2-VL vs InternVL).
        model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
        model = model_cls.from_pretrained(model_id, **model_init_kwargs)
        # for name, param in model.named_parameters():
        #     if not param.requires_grad:
        #         print(f"Frozen: {name}")
        #     else:
        #         print(f"trainable:{name}")

        # LoRA
        self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
        if peft_config is not None:
            print("Applying LoRA...")

            def find_all_linear_names(model, multimodal_keywords):
                # Collect the names of all Linear layers outside the vision tower; these become
                # the LoRA target modules.
                cls = torch.nn.Linear
                lora_module_names = set()
                for name, module in model.named_modules():
                    # LoRA is not applied to the vision modules
                    if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                        continue
                    if isinstance(module, cls):
                        lora_module_names.add(name)
                # NOTE(review): mutating the set while iterating it — works in CPython only
                # because at most one matching entry is typically removed; confirm.
                for m in lora_module_names:  # needed for 16-bit
                    if "embed_tokens" in m:
                        lora_module_names.remove(m)
                return list(lora_module_names)

            target_modules = find_all_linear_names(model, self.vision_modules_keywords)
            peft_config.target_modules = target_modules
            model = get_peft_model(model, peft_config)

        # Freeze vision modules
        if freeze_vision_modules:
            print("Freezing vision modules...")
            for n, p in model.named_parameters():
                if any(keyword in n for keyword in self.vision_modules_keywords):
                    p.requires_grad = False

        # Compute the number of trainable parameters and print the parameter that is trainable
        # for name, param in model.named_parameters():
        #     print(name, param.requires_grad)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in trainable_params)
        # for n, p in model.named_parameters():
        #     if p.requires_grad:
        #         print(n, p.shape)
        print(f"Total trainable parameters: {total_params}")

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model (used for the KL penalty); skipped entirely when beta == 0.
        self.beta = args.beta
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)

        # Processing class: build a processor wrapping a fast tokenizer when the caller passed none.
        if processing_class is None:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id, local_files_only=False, use_fast=True, trust_remote_code=True
            )
            processing_cls = self.vlm_module.get_processing_class()
            processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
            processing_class.tokenizer = tokenizer
            # Forward any processor-specific kwargs (e.g. pixel budget) onto the right sub-component.
            for component, processing_keyword in self.vlm_module.get_custom_processing_keywords():
                if processing_keyword in kwargs:
                    # If we cannot find component in processing_class, return the processing_class itself
                    processing_component = getattr(processing_class, component, processing_class)
                    setattr(processing_component, processing_keyword, kwargs[processing_keyword])
            if getattr(processing_class, "tokenizer", None) is not None:
                pad_token_id = processing_class.tokenizer.pad_token_id
                processing_class.pad_token_id = pad_token_id
                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
            else:
                assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
                pad_token_id = processing_class.pad_token_id

        self.vlm_module.post_model_init(model, processing_class)
        # NOTE(review): self.ref_model may be None here (beta == 0 or PEFT) — presumably
        # post_model_init tolerates None; verify against the VLM module implementations.
        self.vlm_module.post_model_init(self.ref_model, processing_class)

        # Reward functions: string entries are model IDs, loaded as 1-label sequence classifiers.
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        # NOTE(review): the next assignment is immediately overwritten — max_prompt_length is
        # deliberately forced to None (prompt truncation unsupported); the warning below explains it.
        self.max_prompt_length = args.max_prompt_length
        self.max_prompt_length = None
        if args.max_prompt_length is not None:
            warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=1,  # HACK
            pad_token_id=pad_token_id,
        )
        if hasattr(self.vlm_module, "get_eos_token_id"):  # For InternVL
            self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
        self.beta = args.beta
        self.epsilon_low = args.epsilon
        # Asymmetric clipping (DAPO-style): upper bound falls back to the lower bound when unset.
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            # if self.is_deepspeed_enabled:
            if is_deepspeed_zero3_enabled():
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # Ensure use_cache is disabled
        model.config.use_cache = False

        # Enable gradient checkpointing on the base model for PEFT
        if is_peft_model(model):
            model.base_model.gradient_checkpointing_enable()
        # Enable gradient checkpointing for non-PEFT models
        else:
            if getattr(model, "language_model", None) is not None:
                # For InternVL; these operations are copied from the original training script of InternVL
                model.language_model.config.use_cache = False
                model.vision_model.gradient_checkpointing = True
                model.vision_model.encoder.gradient_checkpointing = True
                model.language_model._set_gradient_checkpointing()
                # This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
                args.gradient_checkpointing = False
            else:
                model.gradient_checkpointing_enable()

        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            # Reentrant checkpointing needs inputs with requires_grad=True to propagate gradients.
            model.enable_input_require_grads()

        return model

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
        logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits  # (B, L, V)
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    def _prepare_inputs(self, inputs):
        # Simple pass-through, just like original
        return inputs

    def _get_key_from_inputs(self, x, key):
        # Normalize a sample field to a list; asserts the key exists in the sample.
        ele = x.get(key, None)
        assert ele is not None, f"The key {key} is not found in the input"
        if isinstance(ele, list):
            return [e for e in ele]
        else:
            return [ele]

    def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
        """Generate completions for a batch of prompts, score them with all reward functions,
        and return the tensors needed by `compute_loss` (ids, masks, logps, advantages)."""
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)

        # Handle both pre-loaded images and image paths
        images = []
        for x in inputs:
            if "image" in x:
                imgs = self._get_key_from_inputs(x, "image")
            elif "image_path" in x and x["image_path"] is not None:
                imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
            else:
                imgs = []

            for img in imgs:
                try:
                    # Ensure minimum dimensions of 28 pixels
                    w, h = img.size
                    if w < 28 or h < 28:
                        # Calculate new dimensions maintaining aspect ratio
                        if w < h:
                            new_w = 28
                            new_h = int(h * (28/w))
                        else:
                            new_h = 28
                            new_w = int(w * (28/h))
                        img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
                # NOTE(review): bare except silently keeps the original image on any resize
                # failure — best-effort by design, but consider narrowing to Exception.
                except:
                    pass
                images.append(img)

        prompt_inputs = self.vlm_module.prepare_model_inputs(
            self.processing_class,
            prompts_text,
            images,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # Generate completions
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            generate_returned_result = unwrapped_model.generate(
                **{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
                generation_config=self.generation_config,
            )
            prompt_length = prompt_ids.size(1)
            if not self.vlm_module.is_embeds_input():
                prompt_completion_ids = generate_returned_result
                prompt_ids = prompt_completion_ids[:, :prompt_length]
                completion_ids = prompt_completion_ids[:, prompt_length:]
            else:
                # In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
                # So the returned result of the `generate` method only contains the completion ids
                completion_ids = generate_returned_result
                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        # Get the multimodal inputs
        multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
        multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}

        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
            # computation here, and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
                old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
            else:
                # PEFT case: the base model with the adapter disabled IS the reference model.
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        model, prompt_completion_ids, attention_mask, **multimodal_inputs
                    )
            if ref_per_token_logps is not None:
                ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        # No need to duplicate prompts as we're not generating multiple completions per prompt
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # No need to duplicate prompts as we're not generating multiple completions per prompt
                        # reward_kwargs[key].extend([example[key]] * self.num_generations)
                        reward_kwargs[key].extend([example[key]])
                output_reward_func = reward_func(prompts=prompts, completions=completions,**reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards across processes
        rewards_per_func = self.accelerator.gather(rewards_per_func)
        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards
        # Each group consists of num_generations completions for the same prompt
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        # 1e-4 guards against division by zero when all rewards in a group are identical.
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Get only the local slice of advantages
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "multimodal_inputs": multimodal_inputs
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """GRPO clipped-surrogate loss with optional KL penalty; regenerates completions
        every `num_iterations` steps and reuses the buffered batch in between."""
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Check if we need to generate new completions or use buffered ones
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs, model)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1

        # Get the prepared inputs
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        multimodal_inputs = inputs["multimodal_inputs"]

        # Concatenate for full sequence
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        # Get the current policy's log probabilities
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
        per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]

        # Get the advantages from inputs
        advantages = inputs["advantages"]

        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
        # and use per_token_logps.detach() instead
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()

        # Compute the policy ratio and clipped version
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

        # Add KL penalty if beta > 0
        if self.beta > 0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            # k3 estimator of KL(pi || pi_ref): exp(r) - r - 1 with r = ref_logps - logps.
            per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            per_token_loss = per_token_loss + self.beta * per_token_kl

            # Log KL divergence
            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        # Compute final loss
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log clip ratio
        is_clipped = (per_token_loss1 < per_token_loss2).float()
        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
        self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        # NOTE(review): the BibTeX entry below appears to lack its closing brace — confirm
        # against the upstream TRL template before relying on the rendered citation.
        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year         = 2024,
                eprint       = {arXiv:2402.03300},
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))

    def _get_train_sampler(self) -> Sampler:
        """Returns a sampler that ensures proper data sampling for GRPO training."""
        effective_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
        return RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=effective_batch_size // self.num_generations,
            repeat_count=self.num_iterations,
            seed=self.args.seed,
        )

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        """Returns a sampler for evaluation."""
        return RepeatRandomSampler(
            data_source=eval_dataset,
            mini_repeat_count=self.num_generations,
            seed=self.args.seed,
        )


================================================
FILE: src/open-r1-multimodal/src/open_r1/utils/__init__.py
================================================


================================================
FILE: src/open-r1-multimodal/src/open_r1/utils/callbacks.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess from typing import List from transformers import TrainerCallback from transformers.trainer_callback import TrainerControl, TrainerState from transformers.training_args import TrainingArguments from .evaluation import run_benchmark_jobs from .hub import push_to_hub_revision def is_slurm_available() -> bool: # returns true if a slurm queueing system is available try: subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return True except FileNotFoundError: return False class DummyConfig: def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) class PushToHubRevisionCallback(TrainerCallback): def __init__(self, model_config) -> None: self.model_config = model_config def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): if state.is_world_process_zero: global_step = state.global_step # WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken dummy_config = DummyConfig( hub_model_id=args.hub_model_id, hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}", output_dir=f"{args.output_dir}/checkpoint-{global_step}", system_prompt=args.system_prompt, ) future = push_to_hub_revision( dummy_config, extra_ignore_patterns=["*.pt"] ) # don't push the optimizer states if is_slurm_available(): dummy_config.benchmarks = args.benchmarks def run_benchmark_callback(_): print(f"Checkpoint {global_step} pushed to hub.") run_benchmark_jobs(dummy_config, self.model_config) future.add_done_callback(run_benchmark_callback) CALLBACKS = { "push_to_hub_revision": PushToHubRevisionCallback, } def get_callbacks(train_config, model_config) -> List[TrainerCallback]: callbacks = [] for callback_name in train_config.callbacks: if callback_name not in CALLBACKS: raise ValueError(f"Callback {callback_name} not found in 
CALLBACKS.") callbacks.append(CALLBACKS[callback_name](model_config)) return callbacks ================================================ FILE: src/open-r1-multimodal/src/open_r1/utils/evaluation.py ================================================ import subprocess from typing import TYPE_CHECKING, Dict, Union from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id if TYPE_CHECKING: from trl import GRPOConfig, SFTConfig, ModelConfig import os # We need a special environment setup to launch vLLM from within Slurm training jobs. # - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105 # - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269 user_home_directory = os.path.expanduser("~") VLLM_SLURM_PREFIX = [ "env", "-i", "bash", "-c", f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ", ] def register_lighteval_task( configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0 ): """Registers a LightEval task configuration. - Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl - Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks Args: configs (Dict[str, str]): The dictionary to store the task configuration. eval_suite (str, optional): The evaluation suite. task_name (str): The name of the task. task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0". num_fewshot (int, optional): The number of few-shot examples. Defaults to 0. is_custom_task (bool, optional): Whether the task is a custom task. Defaults to False. 
""" # Format task list in lighteval format task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(",")) configs[task_name] = task_list LIGHTEVAL_TASKS = {} register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0) register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0) register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25_part1", "aime25:part1", 0) register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0) def get_lighteval_tasks(): return list(LIGHTEVAL_TASKS.keys()) SUPPORTED_BENCHMARKS = get_lighteval_tasks() def run_lighteval_job( benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig" ) -> None: task_list = LIGHTEVAL_TASKS[benchmark] model_name = training_args.hub_model_id model_revision = training_args.hub_model_revision # For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM num_gpus = get_gpu_count_for_vllm(model_name, model_revision) if get_param_count_from_repo_id(model_name) >= 30_000_000_000: tensor_parallel = True else: tensor_parallel = False cmd = VLLM_SLURM_PREFIX.copy() cmd_args = [ f"--gres=gpu:{num_gpus}", f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}", "slurm/evaluate.slurm", benchmark, f'"{task_list}"', model_name, model_revision, f"{tensor_parallel}", f"{model_args.trust_remote_code}", ] if training_args.system_prompt is not None: cmd_args.append(f"--system_prompt={training_args.system_prompt}") cmd[-1] += " " + " ".join(cmd_args) subprocess.run(cmd, check=True) def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None: benchmarks = training_args.benchmarks if len(benchmarks) == 1 and benchmarks[0] == "all": benchmarks = get_lighteval_tasks() # Evaluate on all supported benchmarks. 
Later we may want to include a `chat` option # that just evaluates on `ifeval` and `mt_bench` etc. for benchmark in benchmarks: print(f"Launching benchmark `{benchmark}`") if benchmark in get_lighteval_tasks(): run_lighteval_job(benchmark, training_args, model_args) else: raise ValueError(f"Unknown benchmark {benchmark}") ================================================ FILE: src/open-r1-multimodal/src/open_r1/utils/hub.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging import re from concurrent.futures import Future from transformers import AutoConfig from huggingface_hub import ( create_branch, create_repo, get_safetensors_metadata, list_repo_commits, list_repo_files, list_repo_refs, repo_exists, upload_folder, ) from trl import GRPOConfig, SFTConfig logger = logging.getLogger(__name__) def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future: """Pushes the model to branch on a Hub repo.""" # Create a repo if it doesn't exist yet repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True) # Get initial commit to branch from initial_commit = list_repo_commits(training_args.hub_model_id)[-1] # Now create the branch we'll be pushing to create_branch( repo_id=training_args.hub_model_id, branch=training_args.hub_model_revision, revision=initial_commit.commit_id, exist_ok=True, ) logger.info(f"Created target repo at {repo_url}") logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...") ignore_patterns = ["checkpoint-*", "*.pth"] ignore_patterns.extend(extra_ignore_patterns) future = upload_folder( repo_id=training_args.hub_model_id, folder_path=training_args.output_dir, revision=training_args.hub_model_revision, commit_message=f"Add {training_args.hub_model_revision} checkpoint", ignore_patterns=ignore_patterns, run_as_future=True, ) logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!") return future def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig): """Checks if a given Hub revision exists.""" if repo_exists(training_args.hub_model_id): if training_args.push_to_hub_revision is True: # First check if the revision exists revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches] # If the revision exists, we next check it has a README file if training_args.hub_model_revision in revisions: repo_files = list_repo_files( 
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision ) if "README.md" in repo_files and training_args.overwrite_hub_revision is False: raise ValueError( f"Revision {training_args.hub_model_revision} already exists. " "Use --overwrite_hub_revision to overwrite it." ) def get_param_count_from_repo_id(repo_id: str) -> int: """Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID.""" try: metadata = get_safetensors_metadata(repo_id) return list(metadata.parameter_count.values())[0] except Exception: # Pattern to match products (like 8x7b) and single values (like 42m) pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])" matches = re.findall(pattern, repo_id.lower()) param_counts = [] for full_match, number1, _, _, number2, _, unit in matches: if number2: # If there's a second number, it's a product number = float(number1) * float(number2) else: # Otherwise, it's a single value number = float(number1) if unit == "b": number *= 1_000_000_000 # Convert to billion elif unit == "m": number *= 1_000_000 # Convert to million param_counts.append(number) if len(param_counts) > 0: # Return the largest number return int(max(param_counts)) else: # Return -1 if no match found return -1 def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int: """vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs. This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model. 
""" config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True) # Get number of attention heads num_heads = config.num_attention_heads # Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus while num_heads % num_gpus != 0 or 64 % num_gpus != 0: logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus") num_gpus -= 1 return num_gpus ================================================ FILE: src/open-r1-multimodal/src/open_r1/utils/math.py ================================================ from math_verify import parse, verify def compute_score(solution_str, ground_truth) -> float: retval = 0. if solution_str == ground_truth: return 1.0 if float(verify(parse(solution_str), parse(ground_truth))) > 0: return 1.0 try: answer = solution_str string_in_last_boxed = last_boxed_only_string(solution_str) if string_in_last_boxed is not None: answer = remove_boxed(string_in_last_boxed) if is_equiv(answer, ground_truth): return 1.0 except Exception as e: print(e) return retval def remove_boxed(s): if "\\boxed " in s: left = "\\boxed " assert s[:len(left)] == left return s[len(left):] left = "\\boxed{" assert s[:len(left)] == left assert s[-1] == "}" return s[len(left):-1] def last_boxed_only_string(string): idx = string.rfind("\\boxed") if "\\boxed " in string: return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx:right_brace_idx + 1] return retval # string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py def is_equiv(str1, str2, 
verbose=False): if str1 is None and str2 is None: print("WARNING: Both None") return True if str1 is None or str2 is None: return False try: ss1 = strip_string(str1) ss2 = strip_string(str2) if verbose: print(ss1, ss2) return ss1 == ss2 except Exception: return str1 == str2 def fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += "\\frac" if substr[0] == "{": new_str += substr else: try: assert len(substr) >= 2 except AssertionError: return string a = substr[0] b = substr[1] if b != "{": if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}{" + b + "}" + post_substr else: new_str += "{" + a + "}{" + b + "}" else: if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}" + b + post_substr else: new_str += "{" + a + "}" + b string = new_str return string def fix_a_slash_b(string): if len(string.split("/")) != 2: return string a = string.split("/")[0] b = string.split("/")[1] try: a = int(a) b = int(b) assert string == "{}/{}".format(a, b) new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" return new_string except AssertionError: return string def remove_right_units(string): # "\\text{ " only ever occurs (at least in the val set) when describing units if "\\text{ " in string: splits = string.split("\\text{ ") assert len(splits) == 2 return splits[0] else: return string def fix_sqrt(string): if "\\sqrt" not in string: return string splits = string.split("\\sqrt") new_string = splits[0] for split in splits[1:]: if split[0] != "{": a = split[0] new_substr = "\\sqrt{" + a + "}" + split[1:] else: new_substr = "\\sqrt" + split new_string += new_substr return new_string def strip_string(string): # linebreaks string = string.replace("\n", "") # remove inverse spaces string = string.replace("\\!", "") # replace \\ with \ string = string.replace("\\\\", "\\") # replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = 
string.replace("dfrac", "frac") # remove \left and \right string = string.replace("\\left", "") string = string.replace("\\right", "") # Remove circ (degrees) string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") # remove dollar signs string = string.replace("\\$", "") # remove units (on the right) string = remove_right_units(string) # remove percentage string = string.replace("\\%", "") string = string.replace("\%", "") # noqa: W605 # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") # if empty, return empty string if len(string) == 0: return string if string[0] == ".": string = "0" + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] # fix sqrt3 --> sqrt{3} string = fix_sqrt(string) # remove spaces string = string.replace(" ", "") # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = fix_fracs(string) # manually change 0.5 --> \frac{1}{2} if string == "0.5": string = "\\frac{1}{2}" # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = fix_a_slash_b(string) return string ================================================ FILE: src/open-r1-multimodal/src/open_r1/utils/pycocotools/coco.py ================================================ import json import time import matplotlib.pyplot as plt from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon import numpy as np import copy import itertools #from . 
import mask as maskUtils import os from collections import defaultdict import sys PYTHON_VERSION = sys.version_info[0] if PYTHON_VERSION == 2: from urllib import urlretrieve elif PYTHON_VERSION == 3: from urllib.request import urlretrieve def _isArrayLike(obj): return hasattr(obj, '__iter__') and hasattr(obj, '__len__') class COCO: def __init__(self, annotation_file=None): """ Constructor of Microsoft COCO helper class for reading and visualizing annotations. :param annotation_file (str): location of annotation file :param image_folder (str): location to the folder that hosts images. :return: """ # load dataset self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict() self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) if not annotation_file == None: # print('loading annotations into memory...') tic = time.time() if type(annotation_file) == dict: dataset = annotation_file else: dataset = json.load(open(annotation_file, 'r')) assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) # print('Done (t={:0.2f}s)'.format(time.time()- tic)) self.dataset = dataset self.createIndex() def createIndex(self): # create index # print('creating index...') anns, cats, imgs = {}, {}, {} imgToAnns,catToImgs = defaultdict(list),defaultdict(list) if 'annotations' in self.dataset: for ann in self.dataset['annotations']: imgToAnns[ann['image_id']].append(ann) anns[ann['id']] = ann if 'images' in self.dataset: for img in self.dataset['images']: imgs[img['id']] = img if 'categories' in self.dataset: for cat in self.dataset['categories']: cats[cat['id']] = cat if 'annotations' in self.dataset and 'categories' in self.dataset: for ann in self.dataset['annotations']: catToImgs[ann['category_id']].append(ann['image_id']) # print('index created!') # create class members self.anns = anns self.imgToAnns = imgToAnns self.catToImgs = catToImgs self.imgs = imgs self.cats = cats def info(self): """ Print information about the 
annotation file. :return: """ for key, value in self.dataset['info'].items(): print('{}: {}'.format(key, value)) def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): """ Get ann ids that satisfy given filter conditions. default skips that filter :param imgIds (int array) : get anns for given imgs catIds (int array) : get anns for given cats areaRng (float array) : get anns for given area range (e.g. [0 inf]) iscrowd (boolean) : get anns for given crowd label (False or True) :return: ids (int array) : integer array of ann ids """ imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] catIds = catIds if _isArrayLike(catIds) else [catIds] if len(imgIds) == len(catIds) == len(areaRng) == 0: anns = self.dataset['annotations'] else: if not len(imgIds) == 0: lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] anns = list(itertools.chain.from_iterable(lists)) else: anns = self.dataset['annotations'] anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] if not iscrowd == None: ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] else: ids = [ann['id'] for ann in anns] return ids def getCatIds(self, catNms=[], supNms=[], catIds=[]): """ filtering parameters. default skips that filter. 
:param catNms (str array) : get cats for given cat names :param supNms (str array) : get cats for given supercategory names :param catIds (int array) : get cats for given cat ids :return: ids (int array) : integer array of cat ids """ catNms = catNms if _isArrayLike(catNms) else [catNms] supNms = supNms if _isArrayLike(supNms) else [supNms] catIds = catIds if _isArrayLike(catIds) else [catIds] if len(catNms) == len(supNms) == len(catIds) == 0: cats = self.dataset['categories'] else: cats = self.dataset['categories'] cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] ids = [cat['id'] for cat in cats] return ids def getImgIds(self, imgIds=[], catIds=[]): ''' Get img ids that satisfy given filter conditions. :param imgIds (int array) : get imgs for given ids :param catIds (int array) : get imgs with all given cats :return: ids (int array) : integer array of img ids ''' imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] catIds = catIds if _isArrayLike(catIds) else [catIds] if len(imgIds) == len(catIds) == 0: ids = self.imgs.keys() else: ids = set(imgIds) for i, catId in enumerate(catIds): if i == 0 and len(ids) == 0: ids = set(self.catToImgs[catId]) else: ids &= set(self.catToImgs[catId]) return list(ids) def loadAnns(self, ids=[]): """ Load anns with the specified ids. :param ids (int array) : integer ids specifying anns :return: anns (object array) : loaded ann objects """ if _isArrayLike(ids): return [self.anns[id] for id in ids] elif type(ids) == int: return [self.anns[ids]] def loadCats(self, ids=[]): """ Load cats with the specified ids. 
:param ids (int array) : integer ids specifying cats :return: cats (object array) : loaded cat objects """ if _isArrayLike(ids): return [self.cats[id] for id in ids] elif type(ids) == int: return [self.cats[ids]] def loadImgs(self, ids=[]): """ Load anns with the specified ids. :param ids (int array) : integer ids specifying img :return: imgs (object array) : loaded img objects """ if _isArrayLike(ids): return [self.imgs[id] for id in ids] elif type(ids) == int: return [self.imgs[ids]] def showAnns(self, anns, draw_bbox=False): """ Display the specified annotations. :param anns (array of object): annotations to display :return: None """ if len(anns) == 0: return 0 if 'segmentation' in anns[0] or 'keypoints' in anns[0]: datasetType = 'instances' elif 'caption' in anns[0]: datasetType = 'captions' else: raise Exception('datasetType not supported') if datasetType == 'instances': ax = plt.gca() ax.set_autoscale_on(False) polygons = [] color = [] for ann in anns: c = (np.random.random((1, 3))*0.6+0.4).tolist()[0] if 'segmentation' in ann: if type(ann['segmentation']) == list: # polygon for seg in ann['segmentation']: poly = np.array(seg).reshape((int(len(seg)/2), 2)) polygons.append(Polygon(poly)) color.append(c) else: # mask t = self.imgs[ann['image_id']] if type(ann['segmentation']['counts']) == list: rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width']) else: rle = [ann['segmentation']] m = maskUtils.decode(rle) img = np.ones( (m.shape[0], m.shape[1], 3) ) if ann['iscrowd'] == 1: color_mask = np.array([2.0,166.0,101.0])/255 if ann['iscrowd'] == 0: color_mask = np.random.random((1, 3)).tolist()[0] for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack( (img, m*0.5) )) if 'keypoints' in ann and type(ann['keypoints']) == list: # turn skeleton into zero-based index sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1 kp = np.array(ann['keypoints']) x = kp[0::3] y = kp[1::3] v = kp[2::3] for sk in sks: if np.all(v[sk]>0): 
plt.plot(x[sk],y[sk], linewidth=3, color=c) plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2) plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2) if draw_bbox: [bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox'] poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]] np_poly = np.array(poly).reshape((4,2)) polygons.append(Polygon(np_poly)) color.append(c) p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) ax.add_collection(p) p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2) ax.add_collection(p) elif datasetType == 'captions': for ann in anns: print(ann['caption']) def loadRes(self, resFile): """ Load result file and return a result api object. :param resFile (str) : file name of result file :return: res (obj) : result api object """ res = COCO() res.dataset['images'] = [img for img in self.dataset['images']] # print('Loading and preparing results...') tic = time.time() if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode): anns = json.load(open(resFile)) elif type(resFile) == np.ndarray: anns = self.loadNumpyAnnotations(resFile) else: anns = resFile assert type(anns) == list, 'results in not an array of objects' annsImgIds = [ann['image_id'] for ann in anns] assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 'Results do not correspond to current coco set' if 'caption' in anns[0]: imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] for id, ann in enumerate(anns): ann['id'] = id+1 elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) for id, ann in enumerate(anns): bb = ann['bbox'] x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], 
bb[1]+bb[3]] if not 'segmentation' in ann: ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] ann['area'] = bb[2]*bb[3] ann['id'] = id+1 ann['iscrowd'] = 0 elif 'segmentation' in anns[0]: res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) for id, ann in enumerate(anns): # now only support compressed RLE format as segmentation results ann['area'] = maskUtils.area(ann['segmentation']) if not 'bbox' in ann: ann['bbox'] = maskUtils.toBbox(ann['segmentation']) ann['id'] = id+1 ann['iscrowd'] = 0 elif 'keypoints' in anns[0]: res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) for id, ann in enumerate(anns): s = ann['keypoints'] x = s[0::3] y = s[1::3] x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y) ann['area'] = (x1-x0)*(y1-y0) ann['id'] = id + 1 ann['bbox'] = [x0,y0,x1-x0,y1-y0] # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) res.dataset['annotations'] = anns res.createIndex() return res def download(self, tarDir = None, imgIds = [] ): ''' Download COCO images from mscoco.org server. 
:param tarDir (str): COCO results directory name imgIds (list): images to be downloaded :return: ''' if tarDir is None: print('Please specify target directory') return -1 if len(imgIds) == 0: imgs = self.imgs.values() else: imgs = self.loadImgs(imgIds) N = len(imgs) if not os.path.exists(tarDir): os.makedirs(tarDir) for i, img in enumerate(imgs): tic = time.time() fname = os.path.join(tarDir, img['file_name']) if not os.path.exists(fname): urlretrieve(img['coco_url'], fname) print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic)) def loadNumpyAnnotations(self, data): """ Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class} :param data (numpy.ndarray) :return: annotations (python nested list) """ print('Converting ndarray to lists...') assert(type(data) == np.ndarray) print(data.shape) assert(data.shape[1] == 7) N = data.shape[0] ann = [] for i in range(N): if i % 1000000 == 0: print('{}/{}'.format(i,N)) ann += [{ 'image_id' : int(data[i, 0]), 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ], 'score' : data[i, 5], 'category_id': int(data[i, 6]), }] return ann def annToRLE(self, ann): """ Convert annotation which can be polygons, uncompressed RLE to RLE. :return: binary mask (numpy 2D array) """ t = self.imgs[ann['image_id']] h, w = t['height'], t['width'] segm = ann['segmentation'] if type(segm) == list: # polygon -- a single object might consist of multiple parts # we merge all parts into one mask rle code rles = maskUtils.frPyObjects(segm, h, w) rle = maskUtils.merge(rles) elif type(segm['counts']) == list: # uncompressed RLE rle = maskUtils.frPyObjects(segm, h, w) else: # rle rle = ann['segmentation'] return rle def annToMask(self, ann): """ Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 
:return: binary mask (numpy 2D array) """ rle = self.annToRLE(ann) m = maskUtils.decode(rle) return m ================================================ FILE: src/open-r1-multimodal/src/open_r1/utils/pycocotools/cocoeval.py ================================================ import numpy as np import datetime import time from collections import defaultdict from pycocotools import mask as maskUtils import copy class COCOeval: # Interface for evaluating detection on the Microsoft COCO dataset. # # The usage for CocoEval is as follows: # cocoGt=..., cocoDt=... # load dataset and results # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object # E.params.recThrs = ...; # set parameters as desired # E.evaluate(); # run per image evaluation # E.accumulate(); # accumulate per image results # E.summarize(); # display summary metrics of results # For example usage see evalDemo.m and http://mscoco.org/. # # The evaluation parameters are as follows (defaults in brackets): # imgIds - [all] N img ids to use for evaluation # catIds - [all] K cat ids to use for evaluation # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation # recThrs - [0:.01:1] R=101 recall thresholds for evaluation # areaRng - [...] A=4 object area ranges for evaluation # maxDets - [1 10 100] M=3 thresholds on max detections per image # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints' # iouType replaced the now DEPRECATED useSegm parameter. # useCats - [1] if true use category labels for evaluation # Note: if useCats=0 category labels are ignored as in proposal scoring. # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified. 
# # evaluate(): evaluates detections on every image and every category and # concats the results into the "evalImgs" with fields: # dtIds - [1xD] id for each of the D detections (dt) # gtIds - [1xG] id for each of the G ground truths (gt) # dtMatches - [TxD] matching gt id at each IoU or 0 # gtMatches - [TxG] matching dt id at each IoU or 0 # dtScores - [1xD] confidence of each dt # gtIgnore - [1xG] ignore flag for each gt # dtIgnore - [TxD] ignore flag for each dt at each IoU # # accumulate(): accumulates the per-image, per-category evaluation # results in "evalImgs" into the dictionary "eval" with fields: # params - parameters used for evaluation # date - date evaluation was performed # counts - [T,R,K,A,M] parameter dimensions (see above) # precision - [TxRxKxAxM] precision for every evaluation setting # recall - [TxKxAxM] max recall for every evaluation setting # Note: precision and recall==-1 for settings with no gt objects. # # See also coco, mask, pycocoDemo, pycocoEvalDemo # # Microsoft COCO Toolbox. version 2.0 # Data, paper, and tutorials available at: http://mscoco.org/ # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. # Licensed under the Simplified BSD License [see coco/license.txt] def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'): ''' Initialize CocoEval using coco APIs for gt and dt :param cocoGt: coco object with ground truth annotations :param cocoDt: coco object with detection results :return: None ''' if not iouType: print('iouType not specified. 
use default iouType segm') self.cocoGt = cocoGt # ground truth COCO API self.cocoDt = cocoDt # detections COCO API self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements self.eval = {} # accumulated evaluation results self._gts = defaultdict(list) # gt for evaluation self._dts = defaultdict(list) # dt for evaluation self.params = Params(iouType=iouType) # parameters self._paramsEval = {} # parameters for evaluation self.stats = [] # result summarization self.ious = {} # ious between all gts and dts if not cocoGt is None: self.params.imgIds = sorted(cocoGt.getImgIds()) self.params.catIds = sorted(cocoGt.getCatIds()) def _prepare(self): ''' Prepare ._gts and ._dts for evaluation based on params :return: None ''' def _toMask(anns, coco): # modify ann['segmentation'] by reference for ann in anns: rle = coco.annToRLE(ann) ann['segmentation'] = rle p = self.params if p.useCats: gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) else: gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) # convert ground truth to mask if iouType == 'segm' if p.iouType == 'segm': _toMask(gts, self.cocoGt) _toMask(dts, self.cocoDt) # set ignore flag for gt in gts: gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0 gt['ignore'] = 'iscrowd' in gt and gt['iscrowd'] if p.iouType == 'keypoints': gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore'] self._gts = defaultdict(list) # gt for evaluation self._dts = defaultdict(list) # dt for evaluation for gt in gts: self._gts[gt['image_id'], gt['category_id']].append(gt) for dt in dts: self._dts[dt['image_id'], dt['category_id']].append(dt) self.evalImgs = defaultdict(list) # per-image per-category evaluation results self.eval = {} # accumulated evaluation results def evaluate(self): ''' Run per image evaluation on 
given images and store results (a list of dict) in self.evalImgs :return: None ''' tic = time.time() #('Running per image evaluation...') p = self.params # add backward compatibility if useSegm is specified in params if not p.useSegm is None: p.iouType = 'segm' if p.useSegm == 1 else 'bbox' print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) # print('Evaluate annotation type *{}*'.format(p.iouType)) p.imgIds = list(np.unique(p.imgIds)) if p.useCats: p.catIds = list(np.unique(p.catIds)) p.maxDets = sorted(p.maxDets) self.params=p self._prepare() # loop through images, area range, max detection number catIds = p.catIds if p.useCats else [-1] if p.iouType == 'segm' or p.iouType == 'bbox': computeIoU = self.computeIoU elif p.iouType == 'keypoints': computeIoU = self.computeOks self.ious = {(imgId, catId): computeIoU(imgId, catId) \ for imgId in p.imgIds for catId in catIds} evaluateImg = self.evaluateImg maxDet = p.maxDets[-1] self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet) for catId in catIds for areaRng in p.areaRng for imgId in p.imgIds ] self._paramsEval = copy.deepcopy(self.params) toc = time.time() #print('DONE (t={:0.2f}s).'.format(toc-tic)) def computeIoU(self, imgId, catId): p = self.params if p.useCats: gt = self._gts[imgId,catId] dt = self._dts[imgId,catId] else: gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] if len(gt) == 0 and len(dt) ==0: return [] inds = np.argsort([-d['score'] for d in dt], kind='mergesort') dt = [dt[i] for i in inds] if len(dt) > p.maxDets[-1]: dt=dt[0:p.maxDets[-1]] if p.iouType == 'segm': g = [g['segmentation'] for g in gt] d = [d['segmentation'] for d in dt] elif p.iouType == 'bbox': g = [g['bbox'] for g in gt] d = [d['bbox'] for d in dt] else: raise Exception('unknown iouType for iou computation') # compute iou between each dt and gt region iscrowd = [int(o['iscrowd']) for o in gt] ious = maskUtils.iou(d,g,iscrowd) 
return ious def computeOks(self, imgId, catId): p = self.params # dimention here should be Nxm gts = self._gts[imgId, catId] dts = self._dts[imgId, catId] inds = np.argsort([-d['score'] for d in dts], kind='mergesort') dts = [dts[i] for i in inds] if len(dts) > p.maxDets[-1]: dts = dts[0:p.maxDets[-1]] # if len(gts) == 0 and len(dts) == 0: if len(gts) == 0 or len(dts) == 0: return [] ious = np.zeros((len(dts), len(gts))) sigmas = p.kpt_oks_sigmas vars = (sigmas * 2)**2 k = len(sigmas) # compute oks between each detection and ground truth object for j, gt in enumerate(gts): # create bounds for ignore regions(double the gt bbox) g = np.array(gt['keypoints']) xg = g[0::3]; yg = g[1::3]; vg = g[2::3] k1 = np.count_nonzero(vg > 0) bb = gt['bbox'] x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2 y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2 for i, dt in enumerate(dts): d = np.array(dt['keypoints']) xd = d[0::3]; yd = d[1::3] if k1>0: # measure the per-keypoint distance if keypoints visible dx = xd - xg dy = yd - yg else: # measure minimum distance to keypoints in (x0,y0) & (x1,y1) z = np.zeros((k)) dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0) dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0) e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2 if k1 > 0: e=e[vg > 0] ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] return ious def evaluateImg(self, imgId, catId, aRng, maxDet): ''' perform evaluation for single category and image :return: dict (single image results) ''' p = self.params if p.useCats: gt = self._gts[imgId,catId] dt = self._dts[imgId,catId] else: gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] if len(gt) == 0 and len(dt) ==0: return None for g in gt: if g['ignore'] or (g['area']aRng[1]): g['_ignore'] = 1 else: g['_ignore'] = 0 # sort dt highest score first, sort gt ignore last gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') gt = [gt[i] for i in gtind] dtind = 
np.argsort([-d['score'] for d in dt], kind='mergesort') dt = [dt[i] for i in dtind[0:maxDet]] iscrowd = [int(o['iscrowd']) for o in gt] # load computed ious ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] T = len(p.iouThrs) G = len(gt) D = len(dt) gtm = np.zeros((T,G)) dtm = np.zeros((T,D)) gtIg = np.array([g['_ignore'] for g in gt]) dtIg = np.zeros((T,D)) if not len(ious)==0: for tind, t in enumerate(p.iouThrs): for dind, d in enumerate(dt): # information about best match so far (m=-1 -> unmatched) iou = min([t,1-1e-10]) m = -1 for gind, g in enumerate(gt): # if this gt already matched, and not a crowd, continue if gtm[tind,gind]>0 and not iscrowd[gind]: continue # if dt matched to reg gt, and on ignore gt, stop if m>-1 and gtIg[m]==0 and gtIg[gind]==1: break # continue to next gt unless better match made if ious[dind,gind] < iou: continue # if match successful and best so far, store appropriately iou=ious[dind,gind] m=gind # if match made store id of match for both dt and gt if m ==-1: continue dtIg[tind,dind] = gtIg[m] dtm[tind,dind] = gt[m]['id'] gtm[tind,m] = d['id'] # set unmatched detections outside of area range to ignore a = np.array([d['area']aRng[1] for d in dt]).reshape((1, len(dt))) dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0))) # store results for given image and category return { 'image_id': imgId, 'category_id': catId, 'aRng': aRng, 'maxDet': maxDet, 'dtIds': [d['id'] for d in dt], 'gtIds': [g['id'] for g in gt], 'dtMatches': dtm, 'gtMatches': gtm, 'dtScores': [d['score'] for d in dt], 'gtIgnore': gtIg, 'dtIgnore': dtIg, } def accumulate(self, p = None): ''' Accumulate per image evaluation results and store the result in self.eval :param p: input params for evaluation :return: None ''' #print('Accumulating evaluation results...') tic = time.time() if not self.evalImgs: print('Please run evaluate() first') # allows input customized parameters if p is None: p = self.params 
p.catIds = p.catIds if p.useCats == 1 else [-1] T = len(p.iouThrs) R = len(p.recThrs) K = len(p.catIds) if p.useCats else 1 A = len(p.areaRng) M = len(p.maxDets) precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories recall = -np.ones((T,K,A,M)) scores = -np.ones((T,R,K,A,M)) # create dictionary for future indexing _pe = self._paramsEval catIds = _pe.catIds if _pe.useCats else [-1] setK = set(catIds) setA = set(map(tuple, _pe.areaRng)) setM = set(_pe.maxDets) setI = set(_pe.imgIds) # get inds to evaluate k_list = [n for n, k in enumerate(p.catIds) if k in setK] m_list = [m for n, m in enumerate(p.maxDets) if m in setM] a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA] i_list = [n for n, i in enumerate(p.imgIds) if i in setI] I0 = len(_pe.imgIds) A0 = len(_pe.areaRng) # retrieve E at each category, area range, and max number of detections for k, k0 in enumerate(k_list): Nk = k0*A0*I0 for a, a0 in enumerate(a_list): Na = a0*I0 for m, maxDet in enumerate(m_list): E = [self.evalImgs[Nk + Na + i] for i in i_list] E = [e for e in E if not e is None] if len(E) == 0: continue dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E]) # different sorting method generates slightly different results. # mergesort is used to be consistent as Matlab implementation. 
inds = np.argsort(-dtScores, kind='mergesort') dtScoresSorted = dtScores[inds] dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds] dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds] gtIg = np.concatenate([e['gtIgnore'] for e in E]) npig = np.count_nonzero(gtIg==0 ) if npig == 0: continue tps = np.logical_and( dtm, np.logical_not(dtIg) ) fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) ) tp_sum = np.cumsum(tps, axis=1).astype(dtype=float) fp_sum = np.cumsum(fps, axis=1).astype(dtype=float) for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): tp = np.array(tp) fp = np.array(fp) nd = len(tp) rc = tp / npig pr = tp / (fp+tp+np.spacing(1)) q = np.zeros((R,)) ss = np.zeros((R,)) if nd: recall[t,k,a,m] = rc[-1] else: recall[t,k,a,m] = 0 # numpy is slow without cython optimization for accessing elements # use python array gets significant speed improvement pr = pr.tolist(); q = q.tolist() for i in range(nd-1, 0, -1): if pr[i] > pr[i-1]: pr[i-1] = pr[i] inds = np.searchsorted(rc, p.recThrs, side='left') try: for ri, pi in enumerate(inds): q[ri] = pr[pi] ss[ri] = dtScoresSorted[pi] except: pass precision[t,:,k,a,m] = np.array(q) scores[t,:,k,a,m] = np.array(ss) self.eval = { 'params': p, 'counts': [T, R, K, A, M], 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'precision': precision, 'recall': recall, 'scores': scores, } toc = time.time() # print('DONE (t={:0.2f}s).'.format( toc-tic)) def summarize(self): ''' Compute and display summary metrics for evaluation results. 
Note this functin can *only* be applied on the default parameter setting ''' def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): p = self.params iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' titleStr = 'Average Precision' if ap == 1 else 'Average Recall' typeStr = '(AP)' if ap==1 else '(AR)' iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ if iouThr is None else '{:0.2f}'.format(iouThr) aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] if ap == 1: # dimension of precision: [TxRxKxAxM] s = self.eval['precision'] # IoU if iouThr is not None: t = np.where(iouThr == p.iouThrs)[0] s = s[t] s = s[:,:,:,aind,mind] else: # dimension of recall: [TxKxAxM] s = self.eval['recall'] if iouThr is not None: t = np.where(iouThr == p.iouThrs)[0] s = s[t] s = s[:,:,aind,mind] if len(s[s>-1])==0: mean_s = -1 else: mean_s = np.mean(s[s>-1]) #print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) return mean_s def _summarizeDets(): stats = np.zeros((12,)) stats[0] = _summarize(1) stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2]) stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2]) stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2]) stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2]) stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2]) stats[6] = _summarize(0, maxDets=self.params.maxDets[0]) stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2]) stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2]) stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2]) return stats def _summarizeKps(): stats = np.zeros((10,)) stats[0] = _summarize(1, maxDets=20) stats[1] = 
_summarize(1, maxDets=20, iouThr=.5) stats[2] = _summarize(1, maxDets=20, iouThr=.75) stats[3] = _summarize(1, maxDets=20, areaRng='medium') stats[4] = _summarize(1, maxDets=20, areaRng='large') stats[5] = _summarize(0, maxDets=20) stats[6] = _summarize(0, maxDets=20, iouThr=.5) stats[7] = _summarize(0, maxDets=20, iouThr=.75) stats[8] = _summarize(0, maxDets=20, areaRng='medium') stats[9] = _summarize(0, maxDets=20, areaRng='large') return stats if not self.eval: raise Exception('Please run accumulate() first') iouType = self.params.iouType if iouType == 'segm' or iouType == 'bbox': summarize = _summarizeDets elif iouType == 'keypoints': summarize = _summarizeKps self.stats = summarize() def __str__(self): self.summarize() class Params: ''' Params for coco evaluation api ''' def setDetParams(self): self.imgIds = [] self.catIds = [] # np.arange causes trouble. the data point on arange is slightly larger than the true value self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True) self.maxDets = [1, 10, 100] self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]] self.areaRngLbl = ['all', 'small', 'medium', 'large'] self.useCats = 1 def setKpParams(self): self.imgIds = [] self.catIds = [] # np.arange causes trouble. 
the data point on arange is slightly larger than the true value self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True) self.maxDets = [20] self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]] self.areaRngLbl = ['all', 'medium', 'large'] self.useCats = 1 self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0 def __init__(self, iouType='segm'): if iouType == 'segm' or iouType == 'bbox': self.setDetParams() elif iouType == 'keypoints': self.setKpParams() else: raise Exception('iouType not supported') self.iouType = iouType # useSegm is deprecated self.useSegm = None ================================================ FILE: src/open-r1-multimodal/src/open_r1/vlm_modules/__init__.py ================================================ from .vlm_module import VLMBaseModule from .qwen_module import Qwen2VLModule __all__ = ["VLMBaseModule", "Qwen2VLModule"] ================================================ FILE: src/open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py ================================================ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor from typing import Dict, Any, Union from trl.data_utils import maybe_apply_chat_template import torch import re from transformers import AutoTokenizer from vlm_modules.vlm_module import VLMBaseModule import math import numpy as np class Qwen2VLModule(VLMBaseModule): def __init__(self): super().__init__() def get_vlm_key(self): return "qwen" def get_model_class(self, model_id: str, model_init_kwargs: dict): if "Qwen2-VL" in model_id: model_cls = Qwen2VLForConditionalGeneration elif "Qwen2.5-VL" in model_id: model_cls = Qwen2_5_VLForConditionalGeneration else: raise ValueError(f"Unsupported model: {model_id}") return model_cls def 
post_model_init(self, model, processing_class): pass def get_processing_class(self): return AutoProcessor def get_vision_modules_keywords(self): return ['visual'] def get_custom_multimodal_keywords(self): return ['pixel_values', 'image_grid_thw'] def get_non_generate_params(self): return [] def get_custom_processing_keywords(self): return [('image_processor', 'max_pixels'), ('image_processor', 'min_pixels')] def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]): prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs] return prompts_text def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False): # FIXME # print(type(prompts_text)) # This could only process pure-multimodal or pure-text inputs if len(images) > 0: prompt_inputs = processing_class( text=prompts_text, images=images, return_tensors=return_tensors, padding=padding, padding_side=padding_side, add_special_tokens=add_special_tokens) else: prompt_inputs = processing_class( text=prompts_text, return_tensors=return_tensors, padding=padding, padding_side=padding_side, add_special_tokens=add_special_tokens) return prompt_inputs @staticmethod def get_question_template(task_type: str): match task_type: case "robust": return "{Question}First output the types of degradations in image briefly in tags, and then output what effects do these degradation have on the image in tags, then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in tags, and then summarize the content of reasoning and the give the answer in tags,provides the user with the answer briefly in .i.e., degradation type here \n influence here\n reasoning process here\nsummary here\nfinal answer" case "rec": return "{Question} First output the thinking process in tags and then output the final answer in tags. 
Output the final answer in JSON format." case "ic": return "{Question} First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here json format answer here " case "odLength": SYSTEM_PROMPT = ( #"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " "First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " "process and answer are enclosed within and tags, respectively, i.e., " " reasoning process here answer here " ) return SYSTEM_PROMPT + '\n' + "{Question}" case _: return "{Question} First output the thinking process in tags and then output the final answer in tags." @staticmethod def format_reward_rec(completions, **kwargs): """Check if the Qwen model output matches a specific format.""" import re import os from datetime import datetime pattern = r".*?\s*.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?" completion_contents = [completion[0]["content"] for completion in completions] matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents] current_time = datetime.now().strftime("%d-%H-%M-%S-%f") if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Format reward -------------\n") for content, match in zip(completion_contents, matches): f.write(f"Content: {content}\n") f.write(f"Has format: {bool(match)}\n") return [1.0 if match else 0.0 for match in matches] @staticmethod def format_reward_robust(completions, **kwargs): import re import os from datetime import datetime pattern = r".*?\s*.*?\s*.*?\s*.*?\s*.*?" 
completion_contents = [completion[0]["content"] for completion in completions] matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents] current_time = datetime.now().strftime("%d-%H-%M-%S-%f") if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Format reward -------------\n") for content, match in zip(completion_contents, matches): f.write(f"Content: {content}\n") f.write(f"Has format: {bool(match)}\n") return [1.0 if match else 0.0 for match in matches] @staticmethod def type_reward(completions, solution, **kwargs): def custom_normalize_reward(score, k_positive=1.0, k_negative=2.0, x0=0.0): sigmoid_output = 0.0 if score >= x0: sigmoid_output = 1 / (1 + math.exp(-k_positive * (score - x0))) else: sigmoid_output = 1 / (1 + math.exp(-k_negative * (score - x0))) normalized_score = 2 * sigmoid_output - 1 return normalized_score def extract_image_degradations(text): match = re.search(r'(.*?)', text, re.DOTALL) if not match: return [] types_string = match.group(1) degradations = re.findall(r'(\w+(?:\s+\w+)*)\(([\d.]+)\)', types_string) result = [] for degradation, strength in degradations: result.append((degradation.strip(), float(strength))) return result def calculate_reward(A, B): reward = 0.0 B_dict = dict(B) matched_keys = set() for degradation_A, strength_A in A: if degradation_A in B_dict: reward += 1 strength_B = B_dict[degradation_A] diff = abs(strength_A - strength_B) reward += (0.5 - diff) matched_keys.add(degradation_A) else: reward -= 1 for degradation_B in B_dict: if degradation_B not in matched_keys: reward -= 1 return reward contents = [completion[0]["content"] for completion in completions] rewards = [] for i in range(len(contents)): content_single = extract_image_degradations(contents[i]) solution_single = extract_image_degradations(solution[i]) score = 
calculate_reward(content_single, solution_single) rewards.append(score) return rewards @staticmethod def accuracy_reward(completions, solution, **kwargs): def extract_answer(text): match = re.search(r'(.*?)', text, re.DOTALL) if match: return match.group(1).strip() return None contents = [completion[0]["content"] for completion in completions] if len(contents) != len(solution): print("Warning: Input list lengths do not match.") return [] rewards = [] for i in range(len(contents)): model_answer = extract_answer(contents[i]) gt_answer = extract_answer(solution[i]) if model_answer == gt_answer: rewards.append(1) else: rewards.append(0) return rewards @staticmethod def length_reward(completions, solution, **kwargs): processor = AutoProcessor.from_pretrained("your_model_path",user_fast=False) tokenizer =processor.tokenizer responses = [completion[0]["content"] for completion in completions] if len(responses) != len(solution): print("Warning: Input list lengths do not match.") return [] rewards = [] for resp, sol in zip(responses, solution): resp_len = len(tokenizer.encode(resp)) sol_len = len(tokenizer.encode(sol)) length_diff = abs(resp_len - sol_len) reward = 1 - (length_diff/sol_len) rewards.append(reward) return rewards @staticmethod def iou_reward(completions, solution, **kwargs): import re import os from datetime import datetime import json def iou(box1, box2): inter_x1 = max(box1[0], box2[0]) inter_y1 = max(box1[1], box2[1]) inter_x2 = min(box1[2]-1, box2[2]-1) inter_y2 = min(box1[3]-1, box2[3]-1) if inter_x1 < inter_x2 and inter_y1 < inter_y2: inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) else: inter = 0 union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter return float(inter)/union contents = [completion[0]["content"] for completion in completions] rewards = [] current_time = datetime.now().strftime("%d-%H-%M-%S-%f") answer_tag_pattern = r'(.*?)' bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]' for content, sol in 
zip(contents, solution): sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1] sol = json.loads(sol.strip()) reward = 0.0 try: content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL) if content_answer_match: content_answer = content_answer_match.group(1).strip() bbox_match = re.search(bbox_pattern, content_answer) if bbox_match: bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))] reward = iou(bbox, sol) except Exception: pass rewards.append(reward) if os.getenv("DEBUG_MODE") == "true": log_path = os.getenv("LOG_PATH") current_time = datetime.now().strftime("%d-%H-%M-%S-%f") image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None problem = kwargs.get("problem")[0] if reward <= 1.0: with open(log_path, "a", encoding='utf-8') as f: f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n") f.write(f"image_path: {image_path}\n") f.write(f"problem: {problem}\n") f.write(f"Content: {content}\n") f.write(f"Solution: {sol}\n") return rewards @staticmethod def select_reward_func(func: str, task_type: str): if func == "accuracy": match task_type: case "robust": return Qwen2VLModule.accuracy_reward case "rec": return Qwen2VLModule.iou_reward case _: raise ValueError(f"Unsupported reward function: {func}") elif func == "format": match task_type: case "robust": return Qwen2VLModule.format_reward_robust case "rec": return Qwen2VLModule.format_reward_rec case _: raise ValueError(f"Unsupported reward function: {func}") elif func == "type": match task_type: case "robust": return Qwen2VLModule.type_reward case "rec": return Qwen2VLModule.format_reward_rec case _: raise ValueError(f"Unsupported reward function: {func}") elif func == "length": match task_type: case "robust": return Qwen2VLModule.length_reward case "rec": return Qwen2VLModule.format_reward_rec case _: raise ValueError(f"Unsupported reward function: {func}") else: raise ValueError(f"Unsupported 
reward function: {func}") ================================================ FILE: src/open-r1-multimodal/src/open_r1/vlm_modules/vlm_module.py ================================================ from abc import ABC, abstractmethod from typing import Dict, Any, Union import torch class VLMBaseModule(ABC): def __init__(self): super().__init__() @abstractmethod def get_vlm_key(self): pass @abstractmethod def get_model_class(self, model_id: str, model_init_kwargs: dict): pass def post_model_init(self, model, processing_class): pass def is_embeds_input(self): return False @abstractmethod def get_processing_class(self): pass @abstractmethod def get_vision_modules_keywords(self): pass @abstractmethod def get_custom_multimodal_keywords(self): pass @abstractmethod def get_non_generate_params(self): pass @abstractmethod def get_custom_processing_keywords(self): pass @abstractmethod def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]): pass @abstractmethod def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens): pass