Repository: brendanhogan/DeepSeekRL-Extended Branch: main Commit: fb3190a65094 Files: 11 Total size: 59.1 KB Directory structure: gitextract_2gjl86h1/ ├── .gitignore ├── LICENSE ├── README.md ├── evaluator.py ├── llms.py ├── main.py ├── plotter.py ├── requirements.txt ├── rldatasets.py ├── run.sh └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 Brendan Hogan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# DeepSeek R1 Implementation

## Motivation

I wanted to recreate DeepSeek R1's results at a smaller scale, focusing on understanding the core mechanics by implementing everything from scratch. This repo trains Qwen 2.5 1.5B-Instruct on the [grade school math dataset](https://github.com/openai/grade-school-math).

This implementation heavily borrows from [Will Brown's work](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) ([@willccbb](https://x.com/willccbb)), but restructures the code into a format optimized for learning and experimentation. The key differences in my implementation are computing the GRPO loss function directly rather than using external RL libraries, and reorganizing the code into a multi-script repo. I hope this helps other people understand things better and provides an easier way to try out smaller-scale ideas.

## Installation

```
pip install -r requirements.txt
```

Required environment variables:

```
export HUGGINGFACE_TOKEN="your-token-here"
huggingface-cli login
```

## Implementation Details

The system consists of several key modules:

### main.py
Contains the core training loop implementing GRPO (Group Relative Policy Optimization). Handles model training, evaluation, and metric tracking. A minimal sketch of the group-relative advantage computation is shown below.
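As a minimal illustration (not the actual training code - the real computation lives in `score_completions` in `main.py`), this is how GRPO turns a group of sampled completions for one question into advantages:

```python
import torch

# Toy rewards for num_chains = 4 completions sampled for a single question.
rewards = torch.tensor([2.5, 0.5, 3.0, 0.5])

# Each completion is scored relative to the other completions in its group,
# so the learning signal is "better or worse than my siblings", not an absolute score.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)  # positive for above-average completions, negative otherwise
```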
### llms.py
Manages model loading and configuration, currently supporting LLaMA and Qwen models through Hugging Face's transformers library. Designed to be easily extensible to other model architectures.

### rldatasets.py
Handles dataset loading and preprocessing, currently focused on GSM8K math problems. Implements custom data loaders for both training and evaluation.

### evaluator.py
Contains evaluation metrics and reward functions, closely following DeepSeek's original implementation.

## Results

Training was conducted on a single H100 GPU. After ~400 training steps:

![Training Results](plots/train_score.png)

And results on the validation set - this shows a clearer sign of learning:

![Eval Results](plots/eval_score.png)

## Future Directions

I'm really pleased to see how well the key mechanics work even in this simplified implementation. Building on this, I am very excited about several directions:

1. Adding self-play capabilities where agents compete and learn from each other using relative rewards. This would create a more dynamic training environment where the reward signal comes from agent interactions rather than fixed metrics.

2. Implementing soft reward structures, particularly for complex reasoning tasks. I've written a framework for AI debate that I'm excited to try out.

3. Expanding into vision-language models (VLMs) to improve world modeling capabilities. I have an idea about using R1-style training to enhance how VLMs build and maintain internal world models that I'm really excited to explore. (If anyone else is interested in this, I would love to talk.)

4. I'd like to do all this experimentation in this framework, so I need to make things faster and support multi-GPU training.


================================================
FILE: evaluator.py
================================================
"""
Abstract base class and implementations for reward computation in RL training.
"""

import re
import torch
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Any


class RewardEvaluator(ABC):
    """
    Abstract base class for reward computation in RL training.

    This class defines the interface for reward evaluators that can be used
    to score model completions during RL training. Implement this class to
    create custom reward functions for different tasks.

    The main methods that need to be implemented are:
    - compute_rewards: Computes rewards for a batch of completions
    - get_reward_breakdown: Converts raw reward scores to a labeled dictionary
    """

    @abstractmethod
    def compute_rewards(
        self,
        prompts: List[List[Dict[str, str]]],
        completions: List[List[Dict[str, str]]],
        answer: Any,
        device: str
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute rewards for a batch of completions.

        Args:
            prompts: List of prompt messages in chat format
                     [{"role": "user", "content": "..."}, ...]
            completions: List of completion messages in chat format
                         [{"role": "assistant", "content": "..."}, ...]
            answer: Ground truth answer(s) for the prompts
            device: Device to place tensors on ("cpu" or "cuda")

        Returns:
            rewards_per_func: Tensor of shape (num_completions, num_reward_functions)
                              containing individual reward function scores
            metrics: Dictionary of aggregated metrics including mean rewards
                     per function and total reward
        """
        pass

    @abstractmethod
    def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:
        """
        Convert raw reward scores tensor to a labeled dictionary.
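        For example, the GSM8K evaluator defined below returns a dictionary of the form
        (values illustrative - they show roughly the maximum score of each reward function):

            {'correctness': 2.0, 'integer_format': 0.5, 'strict_format': 0.5,
             'soft_format': 0.5, 'xml_count': 0.5}
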
        Args:
            reward_scores: Tensor of raw scores from compute_rewards

        Returns:
            Dictionary mapping reward function names to their scores
        """
        pass


def get_evaluator(name: str) -> RewardEvaluator:
    """
    Get the appropriate reward evaluator for a given task.

    Args:
        name: Name of the task/dataset to get evaluator for

    Returns:
        RewardEvaluator instance for the specified task

    Raises:
        NotImplementedError: If evaluator for given task is not implemented
    """
    if name.lower() == "gsm8k":
        return GSM8kEvaluator()
    else:
        raise NotImplementedError(f"No evaluator implemented for {name}")


class GSM8kEvaluator(RewardEvaluator):
    """
    Reward evaluator for the GSM8K math problem dataset.

    Implements reward functions for:
    - Answer correctness
    - Integer format validation
    - XML formatting (strict and soft)
    - XML tag counting
    """

    def __init__(self):
        self.num_reward_functions = 5

    def _extract_xml_answer(self, text: str) -> str:
        """Extract answer from XML tags."""
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    def _correctness_reward(self, prompts, completions, answer) -> List[float]:
        """Reward for correct answer."""
        responses = [completion[0]['content'] for completion in completions]
        extracted = [self._extract_xml_answer(r) for r in responses]
        return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]

    def _int_format_reward(self, completions) -> List[float]:
        """Reward for integer format."""
        responses = [completion[0]['content'] for completion in completions]
        extracted = [self._extract_xml_answer(r) for r in responses]
        return [0.5 if r.isdigit() else 0.0 for r in extracted]

    def _strict_format_reward(self, completions) -> List[float]:
        """Reward for strict XML format."""
        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
        responses = [completion[0]["content"] for completion in completions]
        matches = [bool(re.match(pattern, r)) for r in responses]
        return [0.5 if m else 0.0 for m in matches]

    def _soft_format_reward(self, completions) -> List[float]:
        """Reward for relaxed XML format."""
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
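        # Illustrative examples (the pattern is applied with re.match, so it has to
        # match from the start of the completion):
        #   "<reasoning>2+2=4</reasoning>\n<answer>4</answer>"  -> 0.5
        #   "The answer is <answer>4</answer>"                   -> 0.0 (does not start with <reasoning>)
        # The strict pattern above additionally requires each tag and its content to sit
        # on their own lines, with a trailing newline after </answer>.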
        responses = [completion[0]["content"] for completion in completions]
        matches = [bool(re.match(pattern, r)) for r in responses]
        return [0.5 if m else 0.0 for m in matches]

    def _xml_count_reward(self, completions) -> List[float]:
        """Reward for XML tag counting."""
        def count_xml(text: str) -> float:
            count = 0.0
            if text.count("<reasoning>\n") == 1:
                count += 0.125
            if text.count("\n</reasoning>\n") == 1:
                count += 0.125
            if text.count("\n<answer>\n") == 1:
                count += 0.125
                count -= len(text.split("\n</answer>\n")[-1])*0.001
            if text.count("\n</answer>") == 1:
                count += 0.125
                count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
            return count
        responses = [completion[0]["content"] for completion in completions]
        return [count_xml(r) for r in responses]

    def compute_rewards(
        self,
        prompts: List[List[Dict[str, str]]],
        completions: List[List[Dict[str, str]]],
        answer: Any,
        device: str
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute all rewards for the given completions."""
        num_completions = len(completions)
        rewards_per_func = torch.zeros(num_completions, self.num_reward_functions, device=device)

        # Compute all reward functions
        all_scores = [
            self._correctness_reward(prompts, completions, answer),
            self._int_format_reward(completions),
            self._strict_format_reward(completions),
            self._soft_format_reward(completions),
            self._xml_count_reward(completions)
        ]

        # Fill rewards tensor
        for i, scores in enumerate(all_scores):
            rewards_per_func[:, i] = torch.tensor(scores, dtype=torch.float32, device=device)

        # Compute metrics
        reward_per_func = rewards_per_func.mean(0)

        # Calculate accuracy (perfect correctness score)
        correctness_scores = rewards_per_func[:, 0]  # First reward function is correctness
        num_perfect = (correctness_scores == 2.0).sum().item()
        accuracy = num_perfect / num_completions

        metrics = {
            "rewards/correctness_reward_func": reward_per_func[0].item(),
            "rewards/int_reward_func": reward_per_func[1].item(),
            "rewards/strict_format_reward_func": reward_per_func[2].item(),
            "rewards/soft_format_reward_func": reward_per_func[3].item(),
            "rewards/xmlcount_reward_func": reward_per_func[4].item(),
            "reward": rewards_per_func.sum(dim=1).mean().item(),
            "accuracy": accuracy
        }

        return rewards_per_func, metrics

    def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:
        """Convert reward scores tensor to labeled dictionary."""
        return {
            'correctness': reward_scores[0].item(),
            'integer_format': reward_scores[1].item(),
            'strict_format': reward_scores[2].item(),
            'soft_format': reward_scores[3].item(),
            'xml_count': reward_scores[4].item()
        }


================================================
FILE: llms.py
================================================
"""
Module for loading LLMs and their tokenizers from huggingface.
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase


def get_llm_tokenizer(model_name: str, device: str) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """
    Load and configure a language model and its tokenizer.
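    The model is loaded in bfloat16 with attn_implementation="flash_attention_2", so a GPU
    with FlashAttention-2 support is assumed. Example usage (illustrative, using the
    default model from main.py):

        model, tokenizer = get_llm_tokenizer("Qwen/Qwen2.5-1.5B-Instruct", "cuda")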
Args: model_name: Name or path of the pretrained model to load device: Device to load the model on ('cpu' or 'cuda') Returns: tuple containing: - The loaded language model - The configured tokenizer for that model """ model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map=None, ).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token model.config.pad_token_id = tokenizer.pad_token_id model.config.use_cache = False return model, tokenizer ================================================ FILE: main.py ================================================ """ Implementation of GRPO, DeepSeek style training without external libraries """ import os import json import torch import argparse from tqdm import tqdm from collections import defaultdict from transformers import PreTrainedModel, PreTrainedTokenizerBase, GenerationConfig import llms import utils import evaluator import rldatasets def eval_on_test_set( model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, test_loader: rldatasets.DataLoader, eval_class: evaluator.RewardEvaluator, device: str, args: argparse.Namespace, round_num: int ) -> tuple[dict[str, float], float]: """ Evaluate model performance on test set. Args: model: The model to evaluate tokenizer: Tokenizer for the model test_loader: DataLoader for test set eval_class: Evaluator for computing rewards device: Device to run on args: Training arguments round_num: Current training round number Returns: total_scores: Dictionary of average metrics accuracy: Accuracy on test set """ print("Running evaluation on test set...") # Track metrics across all test examples total_scores = defaultdict(float) num_examples = 0 total_accuracy = 0.0 # Create log file for this evaluation round log_file = os.path.join(args.output_dir, f'eval_metrics_{round_num}.txt') test_loader.reset() with open(log_file, 'w') as f: # Run through test set for question, answer in tqdm(test_loader, desc="Evaluating on test set"): # Generate completions using same function as training _, _, _, _, completions_text, _ = generate_completions( model, tokenizer, question, device, args ) # Score completions using evaluator mock_prompts = [[{'content': question}]] * len(completions_text) mock_completions = [[{'content': completion}] for completion in completions_text] # Make answer array same length as completions answers = [answer] * len(completions_text) rewards_per_func, metrics = eval_class.compute_rewards( prompts=mock_prompts, completions=mock_completions, answer=answers, device=device ) # Track accuracy and accumulate metrics total_accuracy += metrics['accuracy'] for k, v in metrics.items(): total_scores[k] += v num_examples += 1 # Log this example f.write("\n" + "="*50 + "\n") f.write(f"Q# {num_examples}\n") f.write(f"Question: {question}\n") f.write(f"Response: {completions_text[0]}\n") # Log first completion f.write(f"Ground Truth: {answer}\n") f.write("Metrics:\n") for metric, value in metrics.items(): f.write(f"{metric}: {value}\n") f.write(f"Total Score: {rewards_per_func.sum().item()}\n") # Calculate averages avg_scores = {k: v/num_examples for k,v in total_scores.items()} accuracy = total_accuracy / num_examples * 100 # Save metrics metrics_path = os.path.join(args.output_dir, f'eval_metrics_{round_num}.json') with open(metrics_path, 'w') as f: json.dump({**avg_scores, 'accuracy': accuracy}, f, indent=4) if args.verbose: print("\nEvaluation Results:") print("-" * 20) 
print(f"Accuracy: {accuracy:.2f}%") for metric, value in avg_scores.items(): print(f"{metric:15s}: {value:.4f}") print("-" * 20) return avg_scores, accuracy def generate_completions( model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, question: str, device: str, args: argparse.Namespace ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], str]: """ Generate multiple completion sequences for a given prompt using a language model. Args: model: The language model to use for generation tokenizer: Tokenizer corresponding to the model question: The input question/prompt to generate completions for device: Device to run generation on ('cpu' or 'cuda') args: Namespace containing generation parameters Returns: prompt_completion_ids: Tensor containing the full sequence of prompt + completion token IDs prompt_ids: Tensor containing just the prompt token IDs completion_ids: Tensor containing just the completion token IDs attention_mask: Attention mask tensor for the full sequence completions_text: List of decoded completion texts prompt_text: The full formatted prompt text """ # 1. Prepare prompting prompt = [ {'role': 'system', 'content': train_loader.system_prompt}, {'role': 'user', 'content': question} ] prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) prompt_inputs = tokenizer(prompt_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] # Truncate prompt to max length and repeat for number of generations prompt_ids = prompt_ids[:, -args.max_prompt_length:] prompt_mask = prompt_mask[:, -args.max_prompt_length:] # Repeat for number of chains/generations prompt_ids = prompt_ids.repeat(args.num_chains, 1) prompt_mask = prompt_mask.repeat(args.num_chains, 1) # Move tensors to device prompt_ids = prompt_ids.to(device) prompt_mask = prompt_mask.to(device) # Set up generation config generation_config = GenerationConfig( max_new_tokens=args.max_completion_length, do_sample=True, temperature=args.temperature, pad_token_id=tokenizer.pad_token_id ) # Generate completions prompt_completion_ids = model.generate( prompt_ids, attention_mask=prompt_mask, generation_config=generation_config ) # Extract completion ids prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] # Do masking is_eos = completion_ids == tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # Decode completions completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) return prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text def score_completions( completions_text: list[str], question: str, answer: str, eval_class: evaluator.RewardEvaluator, device: str, args: argparse.Namespace ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float], dict]: """ Score model completions and compute advantages for training. 
Args: completions_text: List of generated completion strings question: Original input question/prompt answer: Ground truth answer eval_class: Evaluator class for computing rewards device: Device to place tensors on args: Training arguments Returns: rewards: Raw reward scores for each completion advantages: Computed advantages for policy gradient rewards_per_func: Rewards broken down by individual reward functions metrics: Dictionary of aggregated metrics log_data: Dictionary containing detailed generation and scoring data """ # Build log data dictionary log_data = { 'prompt': { 'text': question, 'answer': answer }, 'generations': [] } # Format inputs as expected by evaluator mock_prompts = [[{'content': question}]] * len(completions_text) mock_completions = [[{'content': completion}] for completion in completions_text] answers = [answer] * len(completions_text) # Get rewards and metrics from evaluator rewards_per_func, metrics = eval_class.compute_rewards( prompts=mock_prompts, completions=mock_completions, answer=answers, device=device ) rewards = rewards_per_func.sum(dim=1) # Store generation data for i, (completion, reward_scores) in enumerate(zip(completions_text, rewards_per_func)): generation_data = { 'response': completion, 'scores': { **eval_class.get_reward_breakdown(reward_scores), 'total_reward': rewards[i].item() } } log_data['generations'].append(generation_data) # Compute advantages mean_grouped_rewards = rewards.view(-1, args.num_chains).mean(dim=1) std_grouped_rewards = rewards.view(-1, args.num_chains).std(dim=1) mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(args.num_chains, dim=0) std_grouped_rewards = std_grouped_rewards.repeat_interleave(args.num_chains, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) metrics["reward_std"] = std_grouped_rewards.mean().item() # Store summary statistics log_data['summary_stats'] = { 'mean_rewards_per_group': mean_grouped_rewards.tolist(), 'std_rewards_per_group': std_grouped_rewards.tolist(), 'advantages': advantages.tolist() } return rewards, advantages, rewards_per_func, metrics, log_data def compute_loss( model: PreTrainedModel, base_model: PreTrainedModel, prompt_completion_ids: torch.Tensor, prompt_ids: torch.Tensor, completion_ids: torch.Tensor, attention_mask: torch.Tensor, completion_mask: torch.Tensor, advantages: torch.Tensor, args: argparse.Namespace ) -> tuple[torch.Tensor, dict[str, float]]: """ Compute the GRPO loss between current and base model. 
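    As a sketch of the computation implemented below: with cur_t and ref_t denoting the
    per-token log-probabilities under the trained and reference models, the KL penalty
    uses the "k3" estimator and the policy term is weighted by the sequence advantage,

        kl_t   = exp(ref_t - cur_t) - (ref_t - cur_t) - 1
        loss_t = -(exp(cur_t - detach(cur_t)) * advantage - kl_weight_beta * kl_t)

    and the final loss averages loss_t over completion tokens, then over sequences.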
Args: model: The current model being trained base_model: The reference model to compare against prompt_completion_ids: Combined prompt and completion token IDs prompt_ids: Token IDs for just the prompt completion_ids: Token IDs for just the completion attention_mask: Attention mask for the full sequence completion_mask: Mask indicating which tokens are from the completion advantages: Advantage values for each sequence args: Training arguments Returns: loss: The computed GRPO loss metrics: Dictionary containing additional metrics like KL divergence """ # Only need the generated tokens' logits logits_to_keep = completion_ids.size(1) # Get reference model logits with torch.inference_mode(): ref_per_token_logps = utils.get_per_token_logps(base_model, prompt_completion_ids, attention_mask, logits_to_keep) # Get training model logits input_ids = torch.cat([prompt_ids, completion_ids], dim=1) per_token_logps = utils.get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute KL divergence per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 # Compute loss with advantages per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) per_token_loss = -(per_token_loss - args.kl_weight_beta * per_token_kl) loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # Additional metrics metrics = {} response_length = completion_mask.sum(1).float().mean().item() metrics["response_length"] = response_length mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() metrics["kl"] = mean_kl.item() return loss, metrics def grpo_loss( model: PreTrainedModel, base_model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, question: str, answer: str, eval_class: evaluator.RewardEvaluator, device: str, round_num: int, training_log_dir: str, args: argparse.Namespace ) -> tuple[torch.Tensor, dict[str, float], float]: """ Compute GRPO loss between the current model and base model. 
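    As a sketch of the flow (all three helpers are defined above in this file):
    generate_completions samples args.num_chains completions for the question,
    score_completions converts their rewards into group-normalized advantages, and
    compute_loss computes the token-level GRPO loss against the reference model.
    It is called from the training loop roughly as:

        loss, metrics = grpo_loss(model, base_model, tokenizer, question, answer,
                                  eval_class, device, round_num, train_log_dir, args)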
Args: model: The current model being trained base_model: The reference model to compare against tokenizer: Tokenizer for the models question: Input question/prompt answer: Ground truth answer eval_class: Evaluator for computing rewards device: Device to run on ('cpu' or 'cuda') round_num: Current training round number training_log_dir: Directory to save training logs args: Training arguments Returns: loss: The computed GRPO loss metrics: Dictionary containing training metrics reward: The total reward for this batch """ # Generate completions prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text = generate_completions( model, tokenizer, question, device, args ) # Score completions rewards, advantages, rewards_per_func, metrics, log_data = score_completions( completions_text, question, answer, eval_class, device, args ) # Write log data log_file = os.path.join(training_log_dir, f'{round_num}_generations.txt') utils.write_generation_log(log_data, log_file) # Compute loss completion_mask = attention_mask[:, prompt_ids.size(1):] loss, loss_metrics = compute_loss( model, base_model, prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completion_mask, advantages, args ) # Combine metrics metrics.update(loss_metrics) return loss, metrics def parse_args(): parser = argparse.ArgumentParser(description="GRPO training arguments") # Model configuration parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="Name/path of base model") parser.add_argument("--dataset_name", type=str, default="gsm8k", help="Dataset to use for training") parser.add_argument("--evaluator", type=str, default="gsm8k", help="Evaluator to use for scoring") # Output and logging parser.add_argument("--output_dir", type=str, default="output", help="Directory to save outputs") parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") parser.add_argument("--save_steps", type=int, default=100, help="Save model every N steps") parser.add_argument("--eval_iterations", type=int, default=20, help="Number of iterations for evaluation") # Optimization hyperparameters parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate") parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam beta1") parser.add_argument("--adam_beta2", type=float, default=0.99, help="Adam beta2") parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") parser.add_argument("--max_grad_norm", type=float, default=0.1, help="Max gradient norm for clipping") parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of gradient accumulation steps") parser.add_argument("--warmup_percent", type=float, default=0.18, help="Percentage of total steps for warmup") parser.add_argument("--update_ref_model", action="store_true", help="Whether to update reference model") parser.add_argument("--update_ref_model_freq", type=int, default=200, help="How often to update reference model") parser.add_argument("--ref_model_mixup_alpha", type=float, default=0.1, help="Alpha parameter for reference model mixup") # Generation parameters parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature") parser.add_argument("--num_chains", type=int, default=16, help="Number of parallel generation chains") parser.add_argument("--max_prompt_length", type=int, default=256, help="Maximum prompt length") parser.add_argument("--max_completion_length", type=int, 
default=786, help="Maximum completion length") # Training parameters parser.add_argument("--num_train_iters", type=int, default=1000, help="Number of training iterations") parser.add_argument("--kl_weight_beta", type=float, default=0.04, help="KL penalty weight") parser.add_argument("--seed", type=int, default=7111994, help="Random seed") args = parser.parse_args() return args if __name__ == "__main__": # Get all args args = parse_args() # Seed everything utils.seed_everything(args.seed) # Set device and enable bf16 device = "cuda" if torch.cuda.is_available() else "cpu" torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True torch.set_float32_matmul_precision('high') ############################### ## Main Experiment settings ## ############################### ## Set which model to train model, tokenizer = llms.get_llm_tokenizer(args.model_name, device) base_model, _ = llms.get_llm_tokenizer(args.model_name, device) ## Set which data set train_loader, test_loader = rldatasets.get_dataloaders(args.dataset_name) ## Set which evaluation criteria to use eval_class = evaluator.get_evaluator(args.evaluator) ############################### # Setup logging os.makedirs(args.output_dir, exist_ok=True) args_dict = vars(args) args_path = os.path.join(args.output_dir, 'args.json') with open(args_path, 'w') as f: json.dump(args_dict, f, indent=4) eval_log_dir = os.path.join(args.output_dir, 'eval_logs') os.makedirs(eval_log_dir, exist_ok=True) train_log_dir = os.path.join(args.output_dir, 'training_logs') os.makedirs(train_log_dir, exist_ok=True) # Setup optimizer for trainer agent with GRPO config settings optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.weight_decay, eps=1e-8 ) # Add linear warmup learning rate scheduler warmup_steps = int(args.warmup_percent * args.num_train_iters) def get_lr(step): if step < warmup_steps: return (step / warmup_steps) return 1.0 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=get_lr) # Begin training accumulated_loss = 0 optimizer.zero_grad() train_metrics_total = {} for round_num in tqdm(range(args.num_train_iters), desc="Training Progress"): # Evaluate on test set every so often if round_num % args.eval_iterations == 0: eval_metrics, eval_accuracy = eval_on_test_set( model=model, tokenizer=tokenizer, test_loader=test_loader, eval_class=eval_class, device=device, args=args, round_num=round_num ) # Save metrics to eval log dir metrics_path = os.path.join(eval_log_dir, f'metrics_{round_num}.json') with open(metrics_path, 'w') as f: json.dump({ 'metrics': eval_metrics, 'accuracy': eval_accuracy }, f, indent=4) # Slowly update ref model if args.update_ref_model and (round_num+1) % args.update_ref_model_freq == 0: with torch.no_grad(): for param, ref_param in zip(model.parameters(), base_model.parameters()): ref_param.data = args.ref_model_mixup_alpha * param.data + (1 - args.ref_model_mixup_alpha) * ref_param.data # Get next question question, answer = next(train_loader) # Do GRPO - generate chains, score, compute advantage, compute loss total_loss, train_metrics = grpo_loss(model, base_model, tokenizer, question, answer, eval_class, device, round_num, train_log_dir, args) # Gradient accumulation total_loss = total_loss # / args.gradient_accumulation_steps total_loss.backward() accumulated_loss += total_loss.item() scheduler.step() # Step optimizer if (round_num + 1) % args.gradient_accumulation_steps == 0: 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() optimizer.zero_grad() # Logs train_metrics["learning_rate"] = scheduler.get_last_lr()[0] train_metrics["loss"] = total_loss.item() * args.gradient_accumulation_steps grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item() train_metrics["grad_norm"] = grad_norm train_metrics_total[round_num] = train_metrics with open(os.path.join(train_log_dir, "train_logs.json"), "w") as f: json.dump(train_metrics_total, f, indent=4) ================================================ FILE: plotter.py ================================================ import os import json import argparse import numpy as np import matplotlib.pyplot as plt import matplotlib.style as style from matplotlib.backends.backend_pdf import PdfPages def moving_average(data, window_size=5): """Calculate moving average with given window size""" weights = np.ones(window_size) / window_size return np.convolve(data, weights, mode='valid') def plot_metrics(output_dir): """ Plot training metrics from training_logs directory. Creates PDF with separate plots for each metric over training steps. Uses a modern, professional style with custom color palette. """ # Load training logs train_logs_path = os.path.join(output_dir, 'training_logs', 'train_logs.json') with open(train_logs_path, 'r') as f: train_logs = json.load(f) # Load evaluation logs eval_logs = {} eval_logs_dir = os.path.join(output_dir, 'eval_logs') for filename in os.listdir(eval_logs_dir): if filename.startswith('metrics_') and filename.endswith('.json'): step = int(filename.split('_')[1].split('.')[0]) with open(os.path.join(eval_logs_dir, filename), 'r') as f: eval_logs[step] = json.load(f) # Set style and color palette plt.style.use('bmh') # Using 'bmh' style which is a modern, clean style colors = ['#2ecc71', '#e74c3c', '#3498db', '#f1c40f', '#9b59b6', '#1abc9c', '#e67e22', '#34495e'] # Create PDF to save all plots pdf_path = os.path.join(output_dir, 'training_plots.pdf') with PdfPages(pdf_path) as pdf: # Plot reward metrics reward_metrics = [ 'rewards/correctness_reward_func', 'rewards/int_reward_func', 'rewards/strict_format_reward_func', 'rewards/soft_format_reward_func', 'rewards/xmlcount_reward_func', 'reward' ] for metric, color in zip(reward_metrics, colors): plt.figure(figsize=(12,7)) steps = [int(x) for x in train_logs.keys()] values = [metrics[metric] for metrics in train_logs.values()] # Plot raw data with low alpha plt.plot(steps, values, color=color, alpha=0.3, linewidth=1.5, label='Raw data') # Calculate and plot moving average if we have enough data points if len(values) > 5: ma_values = moving_average(values) ma_steps = steps[len(steps)-len(ma_values):] plt.plot(ma_steps, ma_values, color=color, linewidth=2.5, label='Moving average') plt.xlabel('Training Steps', fontsize=12) plt.ylabel(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=12) plt.title(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() # Plot learning rate plt.figure(figsize=(12,7)) steps = [int(x) for x in train_logs.keys()] lr_values = [metrics['learning_rate'] for metrics in train_logs.values()] plt.plot(steps, lr_values, color='#e74c3c', linewidth=2.0, label='Learning Rate') plt.xlabel('Training Steps', fontsize=12) plt.ylabel('Learning Rate', fontsize=12) plt.title('Learning Rate Schedule', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() 
pdf.savefig(bbox_inches='tight') plt.close() # Plot reward standard deviation plt.figure(figsize=(12,7)) reward_std = [metrics['reward_std'] for metrics in train_logs.values()] plt.plot(steps, reward_std, color='#3498db', alpha=0.3, linewidth=1.5, label='Reward Std (Raw)') if len(reward_std) > 5: ma_std = moving_average(reward_std) ma_steps = steps[len(steps)-len(ma_std):] plt.plot(ma_steps, ma_std, color='#3498db', linewidth=2.5, label='Reward Std (MA)') plt.xlabel('Training Steps', fontsize=12) plt.ylabel('Standard Deviation', fontsize=12) plt.title('Reward Standard Deviation', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() # Plot loss plt.figure(figsize=(12,7)) loss_values = [metrics['loss'] for metrics in train_logs.values()] plt.plot(steps, loss_values, color='#e67e22', alpha=0.3, linewidth=1.5, label='Loss (Raw)') if len(loss_values) > 5: ma_loss = moving_average(loss_values) ma_steps = steps[len(steps)-len(ma_loss):] plt.plot(ma_steps, ma_loss, color='#e67e22', linewidth=2.5, label='Loss (MA)') plt.xlabel('Training Steps', fontsize=12) plt.ylabel('Loss', fontsize=12) plt.title('Training Loss', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() # Plot KL divergence plt.figure(figsize=(12,7)) kl_values = [metrics['kl'] for metrics in train_logs.values()] plt.plot(steps, kl_values, color='#9b59b6', alpha=0.3, linewidth=1.5, label='KL Divergence (Raw)') if len(kl_values) > 5: ma_kl = moving_average(kl_values) ma_steps = steps[len(steps)-len(ma_kl):] plt.plot(ma_steps, ma_kl, color='#9b59b6', linewidth=2.5, label='KL Divergence (MA)') plt.xlabel('Training Steps', fontsize=12) plt.ylabel('KL Divergence', fontsize=12) plt.title('KL Divergence', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() # Plot evaluation metrics if eval_logs: eval_steps = sorted(eval_logs.keys()) # Plot accuracy plt.figure(figsize=(12,7)) accuracy_values = [eval_logs[step]['accuracy'] for step in eval_steps] plt.plot(eval_steps, accuracy_values, color='#2ecc71', linewidth=2.0, label='Accuracy') plt.xlabel('Training Steps', fontsize=12) plt.ylabel('Accuracy (%)', fontsize=12) plt.title('Evaluation Accuracy', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() # Plot evaluation reward metrics eval_metrics = [key for key in eval_logs[eval_steps[0]]['metrics'].keys()] for metric, color in zip(eval_metrics, colors): plt.figure(figsize=(12,7)) metric_values = [eval_logs[step]['metrics'][metric] for step in eval_steps] plt.plot(eval_steps, metric_values, color=color, linewidth=2.0, label=metric) plt.xlabel('Training Steps', fontsize=12) plt.ylabel(metric.replace('_', ' ').title(), fontsize=12) plt.title(f'Evaluation {metric.replace("_", " ").title()}', fontsize=14, pad=20) plt.grid(True, alpha=0.3) plt.legend() pdf.savefig(bbox_inches='tight') plt.close() if __name__ == "__main__": parser = argparse.ArgumentParser(description='Plot training metrics from logs directory') parser.add_argument('--log_dir', type=str, help='Directory containing training logs') args = parser.parse_args() plot_metrics(args.log_dir) ================================================ FILE: requirements.txt ================================================ absl-py==2.1.0 accelerate==1.3.0 aiohappyeyeballs==2.4.4 aiohttp==3.11.11 aiosignal==1.3.2 annotated-types==0.7.0 anyio==4.8.0 appdirs==1.4.4 argcomplete==1.8.1 astunparse==1.6.3 
async-timeout==5.0.1 attrs==21.2.0 Automat==20.2.0 Babel==2.8.0 backcall==0.2.0 bcrypt==3.2.0 beautifulsoup4==4.10.0 beniget==0.4.1 bleach==4.1.0 blinker==1.4 bottle==0.12.19 Brotli==1.0.9 certifi==2020.6.20 cffi==1.15.0 chardet==4.0.0 charset-normalizer==3.4.1 click==8.0.3 cloud-init==24.4 colorama==0.4.4 command-not-found==0.3 commonmark==0.9.1 configobj==5.0.6 constantly==15.1.0 cryptography==3.4.8 ctop==1.0.0 cycler==0.11.0 datasets==3.2.0 dbus-python==1.2.18 decorator==4.4.2 defusedxml==0.7.1 dill==0.3.8 distlib==0.3.4 distro==1.7.0 distro-info==1.1+ubuntu0.2 docker==5.0.3 docker-pycreds==0.4.0 einops==0.8.0 entrypoints==0.4 exceptiongroup==1.2.2 filelock==3.6.0 flake8==4.0.1 flash_attn==2.7.4.post1 flatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6 fonttools==4.29.1 fpdf==1.7.2 frozenlist==1.5.0 fs==2.4.12 fsspec==2024.3.1 future==0.18.2 gast==0.5.2 gitdb==4.0.12 GitPython==3.1.44 Glances==3.2.4.2 google-pasta==0.2.0 grpcio==1.30.2 h11==0.14.0 h5py==3.6.0 h5py.-debian-h5py-serial==3.6.0 html5lib==1.1 httpcore==1.0.7 httplib2==0.20.2 httpx==0.28.1 huggingface-hub==0.28.1 hyperlink==21.0.0 icdiff==2.0.4 idna==3.3 importlib-metadata==4.6.4 incremental==21.3.0 influxdb==5.3.1 iotop==0.6 ipykernel==6.7.0 ipython==7.31.1 ipython_genutils==0.2.0 jax==0.4.30 jaxlib==0.4.30 jedi==0.18.0 jeepney==0.7.1 Jinja2==3.1.5 jiter==0.8.2 joblib==0.17.0 jsonpatch==1.32 jsonpointer==2.0 jsonschema==3.2.0 jupyter-client==7.1.2 jupyter-core==4.9.1 kaptan==0.5.12 keras==3.6.0 keyring==23.5.0 kiwisolver==1.3.2 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 libtmux==0.10.1 livereload==2.6.3 lxml==4.8.0 lz4==3.1.3+dfsg Markdown==3.3.6 MarkupSafe==2.0.1 matplotlib==3.5.1 matplotlib-inline==0.1.3 mccabe==0.6.1 mkdocs==1.1.2 ml-dtypes==0.5.0 more-itertools==8.10.0 mpmath==0.0.0 msgpack==1.0.3 multidict==6.1.0 multiprocess==0.70.16 namex==0.0.8 nest-asyncio==1.5.4 netifaces==0.11.0 networkx==2.4 numpy==1.21.5 nvidia-ml-py==12.555.43 oauthlib==3.2.0 olefile==0.46 openai==1.61.0 opt-einsum==3.3.0 optree==0.13.1 packaging==21.3 pandas==1.3.5 parso==0.8.1 peft==0.14.0 pexpect==4.8.0 pickleshare==0.7.5 Pillow==9.0.1 pipx==1.0.0 platformdirs==2.5.1 ply==3.11 prompt-toolkit==3.0.28 propcache==0.2.1 protobuf==4.21.12 psutil==5.9.0 ptyprocess==0.7.0 py==1.10.0 pyarrow==19.0.0 pyasn1==0.4.8 pyasn1-modules==0.2.1 pycodestyle==2.8.0 pycparser==2.21 pycryptodomex==3.11.0 pydantic==2.10.6 pydantic_core==2.27.2 pyflakes==2.4.0 Pygments==2.11.2 PyGObject==3.42.1 PyHamcrest==2.0.2 pyinotify==0.9.6 PyJWT==2.3.0 pyOpenSSL==21.0.0 pyparsing==2.4.7 pyrsistent==0.18.1 pyserial==3.5 pysmi==0.3.2 pysnmp==4.4.12 pystache==0.6.0 python-apt==2.4.0+ubuntu4 python-dateutil==2.8.1 python-magic==0.4.24 pythran==0.10.0 pytz==2022.1 PyYAML==5.4.1 pyzmq==22.3.0 regex==2024.11.6 requests==2.32.3 rich==11.2.0 safetensors==0.5.2 scikit-learn==0.23.2 scipy==1.8.0 SecretStorage==3.3.1 sentry-sdk==2.20.0 service-identity==18.1.0 setproctitle==1.3.4 six==1.16.0 smmap==5.0.2 sniffio==1.3.1 sos==4.7.2 soupsieve==2.3.1 ssh-import-id==5.11 sympy==1.12 tensorboard==2.18.0 tensorflow==2.18.0 termcolor==1.1.0 tf_keras==2.18.0 threadpoolctl==3.1.0 tmuxp==1.9.2 tokenizers==0.21.0 torch==2.5.1 torchvision==0.20.1 tornado==6.1 tqdm==4.67.1 traitlets==5.1.1 transformers==4.48.2 triton==3.1.0 trl==0.14.0 Twisted==22.1.0 typing_extensions==4.12.2 ufoLib2==0.13.1 ufw==0.36.1 unattended-upgrades==0.1 unicodedata2==14.0.0 urllib3==2.3.0 userpath==1.8.0 virtualenv==20.13.0+ds wadllib==1.3.6 wandb==0.19.5 wcwidth==0.2.5 webencodings==0.5.1 
websocket-client==1.2.3
Werkzeug==2.0.2
wrapt==1.13.3
xxhash==3.5.0
yarl==1.18.3
zipp==1.0.0
zope.interface==5.4.0


================================================
FILE: rldatasets.py
================================================
"""
Hold all data sets
"""

import random
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, Dataset
from abc import ABC, abstractmethod
from typing import Tuple, Any


class DataLoader(ABC):
    """
    Abstract base class for data loaders.

    This class defines the interface that all dataset loaders should implement.
    Specific dataset loaders should inherit from this class and implement the
    required methods.

    Attributes:
        random (bool): If True, returns items randomly; if False, returns sequentially
        current_index (int): Current position for sequential access
    """

    def __init__(self, random: bool = False) -> None:
        self.random = random
        self.current_index = 0

    @abstractmethod
    def __len__(self) -> int:
        """Return the total number of items in the dataset."""
        pass

    @abstractmethod
    def __iter__(self) -> 'DataLoader':
        """Return self as iterator."""
        return self

    @abstractmethod
    def __next__(self) -> Any:
        """Return the next item(s) in the dataset."""
        pass


def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""


class GSM8KLoader(DataLoader):
    """
    A loader class that provides iteration over GSM8K math problems.

    This class implements both sequential and random access to math problems through
    standard Python iterator protocols. It can be used to iterate over problems either
    in order or randomly, making it suitable for both training and evaluation.

    Attributes:
        questions (List[str]): List of math question strings
        answers (List[str]): List of corresponding answer strings
        random (bool): If True, returns problems randomly; if False, returns sequentially
        current_index (int): Current position in the lists for sequential access
    """

    def __init__(self, questions: list[str], answers: list[str], random: bool = False) -> None:
        super().__init__(random)
        self.questions = questions
        self.answers = answers
        self.pre_prompt = """You will be given a question that involves reasoning. You should reason carefully about the question, then provide your answer.
It is very important that you put your reasoning process inside <reasoning> tags and your final answer inside <answer> tags, like this:

<reasoning>
Your step-by-step reasoning process here
</reasoning>
<answer>
Your final answer here
</answer>

All of your returned text should either be in the <reasoning> or <answer> tags - no text outside! Start each answer by immediately starting with <reasoning>.
It is extremely important you answer in this way - do not put any information or text outside of these tags!
Question: """ self.system_prompt = SYSTEM_PROMPT def __len__(self) -> int: return len(self.questions) def __iter__(self) -> 'GSM8KLoader': return self def __next__(self) -> tuple[str, str]: if self.current_index >= len(self.questions): raise StopIteration if self.random: idx = random.randint(0, len(self.questions) - 1) else: idx = self.current_index self.current_index += 1 return self.questions[idx], self.answers[idx] def reset(self): self.current_index = 0 def build_gsm8k_dataloaders() -> Tuple[GSM8KLoader, GSM8KLoader]: data = load_dataset('openai/gsm8k', 'main')["train"] questions = [] parsed_answers = [] for i in tqdm(range(len(data)), desc="Processing"): # Try to get answer - if is None dont use this sample ans = extract_hash_answer(data[i]['answer']) if ans is None: continue else: questions.append(data[i]['question']) parsed_answers.append(ans) # Randomly split into train/test sets total_samples = len(questions) test_size = int(total_samples * 0.01) # 10% for test set # Generate random indices for test set test_indices = random.sample(range(total_samples), test_size) test_indices_set = set(test_indices) # Convert to numpy arrays for easier indexing questions = np.array(questions) parsed_answers = np.array(parsed_answers) # Create boolean mask for test indices test_mask = np.zeros(total_samples, dtype=bool) test_mask[list(test_indices_set)] = True # Split using boolean indexing test_questions = questions[test_mask] test_answers = parsed_answers[test_mask] train_questions = questions[~test_mask] train_answers = parsed_answers[~test_mask] # Setup data loaders trainloader = GSM8KLoader(train_questions.tolist(), train_answers.tolist()) testloader = GSM8KLoader(test_questions.tolist(), test_answers.tolist()) return trainloader, testloader def get_dataloaders(dataset_name: str) -> Tuple[DataLoader, DataLoader]: """ Factory function to get train and test data loaders for a specified dataset. Args: dataset_name (str): Name of the dataset to load ('gsm8k' currently supported) Returns: Tuple[DataLoader, DataLoader]: Train and test data loaders Raises: ValueError: If dataset_name is not supported """ if dataset_name.lower() == 'gsm8k': return build_gsm8k_dataloaders() else: raise ValueError(f"Dataset {dataset_name} not supported. Currently only 'gsm8k' is available.") if __name__ == "__main__": trainloader, testloader = get_dataloaders('gsm8k') ================================================ FILE: run.sh ================================================ # python main.py --output_dir "final1" --verbose python plotter.py --log_dir "final1" ================================================ FILE: utils.py ================================================ import os import torch import random import numpy as np import torch.nn.functional as F from typing import Any, Dict, Optional import re #################### ## MISC FUNCTIONS ## #################### def clean_spaces_preserve_newlines(text): # Replace multiple spaces with a single space, but preserve newlines lines = text.split("\n") # Split by newlines cleaned_lines = [" ".join(re.split(r"\s+", line)).strip() for line in lines] # Remove extra spaces in each line return "\n".join(cleaned_lines) # Join the lines back with newlines def seed_everything(seed: int) -> None: """ Set random seed for reproducibility across multiple libraries. This function sets consistent random seeds for Python's random module, NumPy, PyTorch (both CPU and CUDA), and configures CUDNN for deterministic operation. This ensures reproducible results across multiple runs. 
Args: seed: The random seed to use for all random number generators """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) # Additional settings for reproducibility torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def write_generation_log(log_data: Dict[str, Any], log_file: str) -> None: """ Write generation log data to a text file. Args: log_data: Dictionary containing prompt and generation data log_file: Path to output log file """ with open(log_file, 'w') as f: # Write prompt section f.write("###### ORIGINAL PROMPT #####\n\n") f.write(log_data['prompt']['text'] + "\n\n") f.write("#### ANS ####\n\n") f.write(str(log_data['prompt']['answer']) + "\n") # Write each generation for i, gen in enumerate(log_data['generations'], 1): f.write(f"#### GENERATION {i} RESPONSE ####\n\n") f.write(gen['response'] + "\n\n") f.write(f"#### GENERATION {i} SCORES ####\n") # Write individual scores f.write(f"Correctness: {gen['scores']['correctness']}\n") f.write(f"Integer format: {gen['scores']['integer_format']}\n") f.write(f"Strict format: {gen['scores']['strict_format']}\n") f.write(f"Soft format: {gen['scores']['soft_format']}\n") f.write(f"XML count: {gen['scores']['xml_count']}\n") f.write(f"Total reward: {gen['scores']['total_reward']}\n\n") #################################################################################### ## Copied Directly from TRL -> generate log probs per token ######## ## https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py ######## #################################################################################### def selective_log_softmax(logits, index): """ A memory-efficient implementation of the common `log_softmax -> gather` operation. This function is equivalent to the following naive implementation: ```python logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) ``` Args: logits (`torch.Tensor`): Logits tensor of shape `(..., num_classes)`. index (`torch.Tensor`): Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. Returns: `torch.Tensor`: Gathered log probabilities with the same shape as `index`. 
""" if logits.dtype in [torch.float32, torch.float64]: selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) else: # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach per_token_logps = [] for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption row_logps = F.log_softmax(row_logits, dim=-1) row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) per_token_logps.append(row_per_token_logps) per_token_logps = torch.stack(per_token_logps) return per_token_logps def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens