[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# PyPI configuration file\n.pypirc\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Brendan Hogan\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "\n# DeepSeek R1 Implementation\n\n## Motivation\nI wanted to recreate DeepSeek R1's  results at a smaller scale, focusing on understanding the core mechanics by implementing everything from scratch. So this is a repo that trains Qwen1.5B on the [grade school math dataset](https://github.com/openai/grade-school-math).\n\nThis implementation heavily borrows from [Will Brown's  work](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) ([@willccbb](https://x.com/willccbb)), but restructures the code into a format optimized for learning and experimentation.\n\nThe key difference in my implementation is computing the GRPO loss function directly rather than using external RL libraries, and reformatting into a multi script repo.\n\nI hope this might help other people understand things better, and maybe provide an easier way to try out smaller scale ideas etc. \n\n## Installation\n```\npip install -r requirements.txt\n```\n\nRequired environment variables:\n```\nexport HUGGINGFACE_TOKEN=\"your-token-here\"\nhuggingface-cli login\n```\n\n## Implementation Details\n\nThe system consists of several key modules:\n\n### main.py\nContains the core training loop implementing GRPO (Generalized Reward-Powered Optimization). Handles model training, evaluation, and metric tracking. \n\n### llms.py \nManages model loading and configuration, currently supporting LLaMA + Qwen models through Hugging Face's transformers library. Designed to be easily extensible to other model architectures.\n\n### rldatasets.py\nHandles dataset loading and preprocessing, currently focused on GSM8K math problems. Implements custom data loaders for both training and evaluation.\n\n### evaluator.py\nContains evaluation metrics and reward functions, closely following DeepSeek's original implementation.\n\n## Results\nTraining was conducted on a single H100 GPU. After ~400 training steps:\n\n![Training Results](plots/train_score.png)\n\nAnd results on the validation set - this shows a clearer sign of learning: \n![Eval Results](plots/eval_score.png)\n\n## Future Directions\nI'm really pleased to see how well the key mechanics work even in this simplified implementation. Building on this, I am very excited about several directions:\n\n1. Adding self-play capabilities where agents compete and learn from each other using relative rewards. This would create a more dynamic training environment where the reward signal comes from agent interactions rather than fixed metrics.\n\n2. Implementing soft reward structures, particularly for complex reasoning tasks. I've writing a framework for AI debate that I'm excited to try out.\n\n3. Expanding into vision-language models (VLMs) to improve world modeling capabilities. I have an idea about using R1-style training to enhance how VLMs build and maintain internal world models that I'm really excited to explore. (Really excited about this idea - if anyone else is interested I would love to talk.)\n\n4. I'd like to do all this experimentation in this framework, so I need to make things faster, and support multi-gpu training.\n\n\n\n"
  },
  {
    "path": "evaluator.py",
    "content": "\"\"\"\nAbstract base class and implementations for reward computation in RL training.\n\n\"\"\"\n\nimport re\nimport torch\nfrom abc import ABC, abstractmethod\nfrom typing import List, Dict, Tuple, Any\n\nclass RewardEvaluator(ABC):\n    \"\"\"\n    Abstract base class for reward computation in RL training.\n    \n    This class defines the interface for reward evaluators that can be used\n    to score model completions during RL training. Implement this class to\n    create custom reward functions for different tasks.\n    \n    The main methods that need to be implemented are:\n    - compute_rewards: Computes rewards for a batch of completions\n    - get_reward_breakdown: Converts raw reward scores to a labeled dictionary\n    \"\"\"\n    \n    @abstractmethod\n    def compute_rewards(\n        self,\n        prompts: List[List[Dict[str, str]]],\n        completions: List[List[Dict[str, str]]],\n        answer: Any,\n        device: str\n    ) -> Tuple[torch.Tensor, Dict[str, float]]:\n        \"\"\"\n        Compute rewards for a batch of completions.\n        \n        Args:\n            prompts: List of prompt messages in chat format\n                    [{\"role\": \"user\", \"content\": \"...\"}, ...]\n            completions: List of completion messages in chat format\n                        [{\"role\": \"assistant\", \"content\": \"...\"}, ...]\n            answer: Ground truth answer(s) for the prompts\n            device: Device to place tensors on (\"cpu\" or \"cuda\")\n            \n        Returns:\n            rewards_per_func: Tensor of shape (num_completions, num_reward_functions)\n                            containing individual reward function scores\n            metrics: Dictionary of aggregated metrics including mean rewards\n                    per function and total reward\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:\n        \"\"\"\n        Convert raw reward scores tensor to a labeled dictionary.\n        \n        Args:\n            reward_scores: Tensor of raw scores from compute_rewards\n            \n        Returns:\n            Dictionary mapping reward function names to their scores\n        \"\"\"\n        pass\n\n\ndef get_evaluator(name: str) -> RewardEvaluator:\n    \"\"\"\n    Get the appropriate reward evaluator for a given task.\n    \n    Args:\n        name: Name of the task/dataset to get evaluator for\n        \n    Returns:\n        RewardEvaluator instance for the specified task\n        \n    Raises:\n        NotImplementedError: If evaluator for given task is not implemented\n    \"\"\"\n    if name.lower() == \"gsm8k\":\n        return GSM8kEvaluator()\n    else:\n        raise NotImplementedError(f\"No evaluator implemented for {name}\")\n\n\n\nclass GSM8kEvaluator(RewardEvaluator):\n    \"\"\"\n    Reward evaluator for the GSM8K math problem dataset.\n    \n    Implements reward functions for:\n    - Answer correctness\n    - Integer format validation\n    - XML formatting (strict and soft)\n    - XML tag counting\n    \"\"\"\n    \n    def __init__(self):\n        self.num_reward_functions = 5\n    \n    def _extract_xml_answer(self, text: str) -> str:\n        \"\"\"Extract answer from XML tags.\"\"\"\n        answer = text.split(\"<answer>\")[-1]\n        answer = answer.split(\"</answer>\")[0]\n        return answer.strip()\n    \n    def _correctness_reward(self, prompts, completions, answer) -> List[float]:\n        
\"\"\"Reward for correct answer.\"\"\"\n        responses = [completion[0]['content'] for completion in completions]\n        extracted = [self._extract_xml_answer(r) for r in responses]\n        return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]\n\n    def _int_format_reward(self, completions) -> List[float]:\n        \"\"\"Reward for integer format.\"\"\"\n        responses = [completion[0]['content'] for completion in completions]\n        extracted = [self._extract_xml_answer(r) for r in responses]\n        return [0.5 if r.isdigit() else 0.0 for r in extracted]\n\n    def _strict_format_reward(self, completions) -> List[float]:\n        \"\"\"Reward for strict XML format.\"\"\"\n        pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n        responses = [completion[0][\"content\"] for completion in completions]\n        matches = [bool(re.match(pattern, r)) for r in responses]\n        return [0.5 if m else 0.0 for m in matches]\n\n    def _soft_format_reward(self, completions) -> List[float]:\n        \"\"\"Reward for relaxed XML format.\"\"\"\n        pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n        responses = [completion[0][\"content\"] for completion in completions]\n        matches = [bool(re.match(pattern, r)) for r in responses]\n        return [0.5 if m else 0.0 for m in matches]\n\n    def _xml_count_reward(self, completions) -> List[float]:\n        \"\"\"Reward for XML tag counting.\"\"\"\n        def count_xml(text: str) -> float:\n            count = 0.0\n            if text.count(\"<reasoning>\\n\") == 1: count += 0.125\n            if text.count(\"\\n</reasoning>\\n\") == 1: count += 0.125\n            if text.count(\"\\n<answer>\\n\") == 1:\n                count += 0.125\n                count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n            if text.count(\"\\n</answer>\") == 1:\n                count += 0.125\n                count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n            return count\n            \n        responses = [completion[0][\"content\"] for completion in completions]\n        return [count_xml(r) for r in responses]\n\n    def compute_rewards(\n        self,\n        prompts: List[List[Dict[str, str]]],\n        completions: List[List[Dict[str, str]]],\n        answer: Any,\n        device: str\n    ) -> Tuple[torch.Tensor, Dict[str, float]]:\n        \"\"\"Compute all rewards for the given completions.\"\"\"\n\n        num_completions = len(completions)\n        rewards_per_func = torch.zeros(num_completions, self.num_reward_functions, device=device)\n\n        # Compute all reward functions\n        all_scores = [\n            self._correctness_reward(prompts, completions, answer),\n            self._int_format_reward(completions),\n            self._strict_format_reward(completions),\n            self._soft_format_reward(completions),\n            self._xml_count_reward(completions)\n        ]\n        \n        # Fill rewards tensor\n        for i, scores in enumerate(all_scores):\n            rewards_per_func[:, i] = torch.tensor(scores, dtype=torch.float32, device=device)\n        \n        # Compute metrics\n        reward_per_func = rewards_per_func.mean(0)\n        \n        # Calculate accuracy (perfect correctness score)\n        correctness_scores = rewards_per_func[:, 0]  # First reward function is correctness\n        num_perfect = (correctness_scores == 2.0).sum().item()\n        accuracy = num_perfect / num_completions\n       
\n        metrics = {\n            \"rewards/correctness_reward_func\": reward_per_func[0].item(),\n            \"rewards/int_reward_func\": reward_per_func[1].item(),\n            \"rewards/strict_format_reward_func\": reward_per_func[2].item(),\n            \"rewards/soft_format_reward_func\": reward_per_func[3].item(),\n            \"rewards/xmlcount_reward_func\": reward_per_func[4].item(),\n            \"reward\": rewards_per_func.sum(dim=1).mean().item(),\n            \"accuracy\": accuracy\n        }\n        \n        return rewards_per_func, metrics\n\n    def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:\n        \"\"\"Convert reward scores tensor to labeled dictionary.\"\"\"\n        return {\n            'correctness': reward_scores[0].item(),\n            'integer_format': reward_scores[1].item(),\n            'strict_format': reward_scores[2].item(),\n            'soft_format': reward_scores[3].item(),\n            'xml_count': reward_scores[4].item()\n        }\n
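\n# Example (illustrative values, not part of the training pipeline):\n#   ev = GSM8kEvaluator()\n#   rewards, metrics = ev.compute_rewards(\n#       prompts=[[{'content': 'What is 6 * 7?'}]],\n#       completions=[[{'content': '<reasoning>\\n6 * 7 = 42\\n</reasoning>\\n<answer>\\n42\\n</answer>'}]],\n#       answer=['42'], device='cpu')\n#   # rewards has shape (1, 5); metrics['accuracy'] == 1.0\n"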
  },
  {
    "path": "llms.py",
    "content": "\"\"\"\nModule for loading LLMs and their tokenizers from huggingface. \n\n\"\"\"\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase\n\n\ndef get_llm_tokenizer(model_name: str, device: str) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:\n    \"\"\"\n    Load and configure a language model and its tokenizer.\n\n    Args:\n        model_name: Name or path of the pretrained model to load\n        device: Device to load the model on ('cpu' or 'cuda')\n\n    Returns:\n        tuple containing:\n            - The loaded language model\n            - The configured tokenizer for that model\n    \"\"\"\n    model = AutoModelForCausalLM.from_pretrained(\n        model_name,\n        torch_dtype=torch.bfloat16,\n        attn_implementation=\"flash_attention_2\",\n        device_map=None, \n    ).to(device)\n    \n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    tokenizer.pad_token = tokenizer.eos_token\n    model.config.pad_token_id = tokenizer.pad_token_id\n    model.config.use_cache = False\n\n    return model, tokenizer\n"
  },
  {
    "path": "main.py",
    "content": "\"\"\"\nImplementation of GRPO, DeepSeek style training without external libraries \n\"\"\"\nimport os\nimport json\nimport torch\nimport argparse\nfrom tqdm import tqdm\nfrom collections import defaultdict\nfrom transformers import PreTrainedModel, PreTrainedTokenizerBase, GenerationConfig\n\nimport llms\nimport utils\nimport evaluator\nimport rldatasets\n\ndef eval_on_test_set(\n    model: PreTrainedModel,\n    tokenizer: PreTrainedTokenizerBase,\n    test_loader: rldatasets.DataLoader,\n    eval_class: evaluator.RewardEvaluator,\n    device: str,\n    args: argparse.Namespace,\n    round_num: int\n) -> tuple[dict[str, float], float]:\n    \"\"\"\n    Evaluate model performance on test set.\n    \n    Args:\n        model: The model to evaluate\n        tokenizer: Tokenizer for the model\n        test_loader: DataLoader for test set\n        eval_class: Evaluator for computing rewards\n        device: Device to run on\n        args: Training arguments\n        round_num: Current training round number\n        \n    Returns:\n        total_scores: Dictionary of average metrics\n        accuracy: Accuracy on test set\n    \"\"\"\n    print(\"Running evaluation on test set...\")\n    \n    # Track metrics across all test examples\n    total_scores = defaultdict(float)\n    num_examples = 0\n    total_accuracy = 0.0\n\n    # Create log file for this evaluation round\n    log_file = os.path.join(args.output_dir, f'eval_metrics_{round_num}.txt')\n    test_loader.reset()\n    \n    with open(log_file, 'w') as f:\n        # Run through test set\n        for question, answer in tqdm(test_loader, desc=\"Evaluating on test set\"):\n            # Generate completions using same function as training\n            _, _, _, _, completions_text, _ = generate_completions(\n                model, tokenizer, question, device, args\n            )\n            \n            # Score completions using evaluator\n            mock_prompts = [[{'content': question}]] * len(completions_text)\n            mock_completions = [[{'content': completion}] for completion in completions_text]\n            # Make answer array same length as completions\n            answers = [answer] * len(completions_text)\n            rewards_per_func, metrics = eval_class.compute_rewards(\n                prompts=mock_prompts,\n                completions=mock_completions, \n                answer=answers,\n                device=device\n            )\n            \n            # Track accuracy and accumulate metrics\n            total_accuracy += metrics['accuracy']\n                \n            for k, v in metrics.items():\n                total_scores[k] += v\n            num_examples += 1\n\n            # Log this example\n            f.write(\"\\n\" + \"=\"*50 + \"\\n\")\n            f.write(f\"Q# {num_examples}\\n\")\n            f.write(f\"Question: {question}\\n\")\n            f.write(f\"Response: {completions_text[0]}\\n\") # Log first completion\n            f.write(f\"Ground Truth: {answer}\\n\")\n            f.write(\"Metrics:\\n\")\n            for metric, value in metrics.items():\n                f.write(f\"{metric}: {value}\\n\")\n            f.write(f\"Total Score: {rewards_per_func.sum().item()}\\n\")\n\n\n    # Calculate averages\n    avg_scores = {k: v/num_examples for k,v in total_scores.items()}\n    accuracy = total_accuracy / num_examples * 100\n\n    # Save metrics\n    metrics_path = os.path.join(args.output_dir, f'eval_metrics_{round_num}.json')\n    with open(metrics_path, 'w') as f:\n        
json.dump({**avg_scores, 'accuracy': accuracy}, f, indent=4)\n\n    if args.verbose:\n        print(\"\\nEvaluation Results:\")\n        print(\"-\" * 20)\n        print(f\"Accuracy: {accuracy:.2f}%\")\n        for metric, value in avg_scores.items():\n            print(f\"{metric:15s}: {value:.4f}\")\n        print(\"-\" * 20)\n\n    return avg_scores, accuracy\n\ndef generate_completions(\n    model: PreTrainedModel,\n    tokenizer: PreTrainedTokenizerBase, \n    question: str,\n    device: str,\n    args: argparse.Namespace\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], str]:\n    \"\"\"\n    Generate multiple completion sequences for a given prompt using a language model.\n    \n    Args:\n        model: The language model to use for generation\n        tokenizer: Tokenizer corresponding to the model\n        question: The input question/prompt to generate completions for\n        device: Device to run generation on ('cpu' or 'cuda')\n        args: Namespace containing generation parameters\n        \n    Returns:\n        prompt_completion_ids: Tensor containing the full sequence of prompt + completion token IDs\n        prompt_ids: Tensor containing just the prompt token IDs\n        completion_ids: Tensor containing just the completion token IDs\n        attention_mask: Attention mask tensor for the full sequence\n        completions_text: List of decoded completion texts\n        prompt_text: The full formatted prompt text\n    \"\"\"\n    # 1. Prepare prompting\n    prompt = [\n        {'role': 'system', 'content': train_loader.system_prompt},\n        {'role': 'user', 'content': question}\n    ]\n    prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)\n    prompt_inputs = tokenizer(prompt_text, return_tensors=\"pt\", padding=True, padding_side=\"left\", add_special_tokens=False)\n    prompt_ids, prompt_mask = prompt_inputs[\"input_ids\"], prompt_inputs[\"attention_mask\"]\n\n    # Truncate prompt to max length and repeat for number of generations\n    prompt_ids = prompt_ids[:, -args.max_prompt_length:]\n    prompt_mask = prompt_mask[:, -args.max_prompt_length:]\n    \n    # Repeat for number of chains/generations\n    prompt_ids = prompt_ids.repeat(args.num_chains, 1)\n    prompt_mask = prompt_mask.repeat(args.num_chains, 1)\n\n    # Move tensors to device\n    prompt_ids = prompt_ids.to(device)\n    prompt_mask = prompt_mask.to(device)\n\n    # Set up generation config\n    generation_config = GenerationConfig(\n        max_new_tokens=args.max_completion_length,\n        do_sample=True, \n        temperature=args.temperature,\n        pad_token_id=tokenizer.pad_token_id\n    )\n\n    # Generate completions\n    prompt_completion_ids = model.generate(\n        prompt_ids,\n        attention_mask=prompt_mask,\n        generation_config=generation_config\n    )\n\n    # Extract completion ids\n    prompt_length = prompt_ids.size(1)\n    prompt_ids = prompt_completion_ids[:, :prompt_length]\n    completion_ids = prompt_completion_ids[:, prompt_length:]\n\n    # Do masking \n    is_eos = completion_ids == tokenizer.eos_token_id\n    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)\n    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]\n    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)\n    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()\n\n    attention_mask = 
torch.cat([prompt_mask, completion_mask], dim=1)\n\n    # Decode completions\n    completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)\n\n    return prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text\n    \ndef score_completions(\n    completions_text: list[str],\n    question: str,\n    answer: str,\n    eval_class: evaluator.RewardEvaluator,\n    device: str,\n    args: argparse.Namespace\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float], dict]:\n    \"\"\"\n    Score model completions and compute advantages for training.\n    \n    Args:\n        completions_text: List of generated completion strings\n        question: Original input question/prompt\n        answer: Ground truth answer\n        eval_class: Evaluator class for computing rewards\n        device: Device to place tensors on\n        args: Training arguments\n        \n    Returns:\n        rewards: Raw reward scores for each completion\n        advantages: Computed advantages for policy gradient\n        rewards_per_func: Rewards broken down by individual reward functions\n        metrics: Dictionary of aggregated metrics\n        log_data: Dictionary containing detailed generation and scoring data\n    \"\"\"\n    # Build log data dictionary\n    log_data = {\n        'prompt': {\n            'text': question,\n            'answer': answer\n        },\n        'generations': []\n    }\n\n    # Format inputs as expected by evaluator\n    mock_prompts = [[{'content': question}]] * len(completions_text)\n    mock_completions = [[{'content': completion}] for completion in completions_text]\n    answers = [answer] * len(completions_text)\n    \n    # Get rewards and metrics from evaluator\n    rewards_per_func, metrics = eval_class.compute_rewards(\n        prompts=mock_prompts,\n        completions=mock_completions,\n        answer=answers,\n        device=device\n    )\n    rewards = rewards_per_func.sum(dim=1)\n\n    # Store generation data\n    for i, (completion, reward_scores) in enumerate(zip(completions_text, rewards_per_func)):\n        generation_data = {\n            'response': completion,\n            'scores': {\n                **eval_class.get_reward_breakdown(reward_scores),\n                'total_reward': rewards[i].item()\n            }\n        }\n        log_data['generations'].append(generation_data)\n\n    # Compute advantages\n    mean_grouped_rewards = rewards.view(-1, args.num_chains).mean(dim=1)\n    std_grouped_rewards = rewards.view(-1, args.num_chains).std(dim=1)\n\n    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(args.num_chains, dim=0)\n    std_grouped_rewards = std_grouped_rewards.repeat_interleave(args.num_chains, dim=0)\n\n    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)\n    metrics[\"reward_std\"] = std_grouped_rewards.mean().item()\n\n    # Store summary statistics\n    log_data['summary_stats'] = {\n        'mean_rewards_per_group': mean_grouped_rewards.tolist(),\n        'std_rewards_per_group': std_grouped_rewards.tolist(),\n        'advantages': advantages.tolist()\n    }\n\n    return rewards, advantages, rewards_per_func, metrics, log_data\n\ndef compute_loss(\n    model: PreTrainedModel,\n    base_model: PreTrainedModel, \n    prompt_completion_ids: torch.Tensor,\n    prompt_ids: torch.Tensor,\n    completion_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    completion_mask: torch.Tensor,\n    advantages: torch.Tensor,\n    args: 
argparse.Namespace\n) -> tuple[torch.Tensor, dict[str, float]]:\n    \"\"\"\n    Compute the GRPO loss between current and base model.\n    \n    Args:\n        model: The current model being trained\n        base_model: The reference model to compare against\n        prompt_completion_ids: Combined prompt and completion token IDs\n        prompt_ids: Token IDs for just the prompt\n        completion_ids: Token IDs for just the completion\n        attention_mask: Attention mask for the full sequence\n        completion_mask: Mask indicating which tokens are from the completion\n        advantages: Advantage values for each sequence\n        args: Training arguments\n        \n    Returns:\n        loss: The computed GRPO loss\n        metrics: Dictionary containing additional metrics like KL divergence\n    \"\"\"\n\n    # Only need the generated tokens' logits\n    logits_to_keep = completion_ids.size(1)\n\n    # Get reference model logits\n    with torch.inference_mode():\n        ref_per_token_logps = utils.get_per_token_logps(base_model, prompt_completion_ids, attention_mask, logits_to_keep)\n\n    # Get training model logits\n    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n    per_token_logps = utils.get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)\n\n    # Compute per-token KL divergence (k3 estimator)\n    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1\n\n    # Compute loss with advantages\n    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)\n    per_token_loss = -(per_token_loss - args.kl_weight_beta * per_token_kl)\n    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()\n\n    # Additional metrics\n    metrics = {}\n    response_length = completion_mask.sum(1).float().mean().item()\n    metrics[\"response_length\"] = response_length\n    mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()\n    metrics[\"kl\"] = mean_kl.item()\n\n    return loss, metrics\n\ndef grpo_loss(\n        model: PreTrainedModel,\n        base_model: PreTrainedModel,\n        tokenizer: PreTrainedTokenizerBase,\n        question: str,\n        answer: str,\n        eval_class: evaluator.RewardEvaluator,\n        device: str,\n        round_num: int,\n        training_log_dir: str, \n        args: argparse.Namespace\n) -> tuple[torch.Tensor, dict[str, float]]:\n    \"\"\"\n    Compute GRPO loss between the current model and base model.\n    \n    Args:\n        model: The current model being trained\n        base_model: The reference model to compare against\n        tokenizer: Tokenizer for the models\n        question: Input question/prompt\n        answer: Ground truth answer\n        eval_class: Evaluator for computing rewards\n        device: Device to run on ('cpu' or 'cuda')\n        round_num: Current training round number\n        training_log_dir: Directory to save training logs\n        args: Training arguments\n        \n    Returns:\n        loss: The computed GRPO loss\n        metrics: Dictionary containing training metrics\n    \"\"\"\n    # Generate completions\n    prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text = generate_completions(\n        model, tokenizer, question, device, args\n    )\n\n    # Score completions\n    rewards, advantages, rewards_per_func, metrics, log_data = 
score_completions(\n        completions_text, question, answer, eval_class, device, args\n    )\n\n    # Write log data\n    log_file = os.path.join(training_log_dir, f'{round_num}_generations.txt')\n    utils.write_generation_log(log_data, log_file)\n\n    # Compute loss\n    completion_mask = attention_mask[:, prompt_ids.size(1):]\n    loss, loss_metrics = compute_loss(\n        model, base_model, prompt_completion_ids, prompt_ids, completion_ids,\n        attention_mask, completion_mask, advantages, args\n    )\n\n    # Combine metrics\n    metrics.update(loss_metrics)\n\n    return loss, metrics\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"GRPO training arguments\")\n    \n    # Model configuration\n    parser.add_argument(\"--model_name\", type=str, default=\"Qwen/Qwen2.5-1.5B-Instruct\", help=\"Name/path of base model\")\n    parser.add_argument(\"--dataset_name\", type=str, default=\"gsm8k\", help=\"Dataset to use for training\")\n    parser.add_argument(\"--evaluator\", type=str, default=\"gsm8k\", help=\"Evaluator to use for scoring\")\n\n    # Output and logging\n    parser.add_argument(\"--output_dir\", type=str, default=\"output\", help=\"Directory to save outputs\")\n    parser.add_argument(\"--verbose\", action=\"store_true\", help=\"Enable verbose logging\")\n    parser.add_argument(\"--save_steps\", type=int, default=100, help=\"Save model every N steps\")\n    parser.add_argument(\"--eval_iterations\", type=int, default=20, help=\"Evaluate on the test set every N training iterations\")\n\n    # Optimization hyperparameters\n    parser.add_argument(\"--learning_rate\", type=float, default=5e-6, help=\"Learning rate\")\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"Adam beta1\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.99, help=\"Adam beta2\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1, help=\"Weight decay\")\n    parser.add_argument(\"--max_grad_norm\", type=float, default=0.1, help=\"Max gradient norm for clipping\")\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=4, help=\"Number of gradient accumulation steps\")\n    parser.add_argument(\"--warmup_percent\", type=float, default=0.18, help=\"Fraction of total steps used for linear warmup\")\n    parser.add_argument(\"--update_ref_model\", action=\"store_true\", help=\"Whether to update reference model\")\n    parser.add_argument(\"--update_ref_model_freq\", type=int, default=200, help=\"How often to update reference model\")\n    parser.add_argument(\"--ref_model_mixup_alpha\", type=float, default=0.1, help=\"Alpha parameter for reference model mixup\")\n\n\n    # Generation parameters\n    parser.add_argument(\"--temperature\", type=float, default=0.9, help=\"Sampling temperature\")\n    parser.add_argument(\"--num_chains\", type=int, default=16, help=\"Number of parallel generation chains\")\n    parser.add_argument(\"--max_prompt_length\", type=int, default=256, help=\"Maximum prompt length\")\n    parser.add_argument(\"--max_completion_length\", type=int, default=786, help=\"Maximum completion length\")\n\n    # Training parameters\n    parser.add_argument(\"--num_train_iters\", type=int, default=1000, help=\"Number of training iterations\")\n    parser.add_argument(\"--kl_weight_beta\", type=float, default=0.04, help=\"KL penalty weight\")\n    parser.add_argument(\"--seed\", type=int, default=7111994, help=\"Random seed\")\n\n    args = parser.parse_args()\n    return args\n\nif __name__ == 
\"__main__\":\n\n    # Get all args \n    args = parse_args() \n    \n    # Seed everything \n    utils.seed_everything(args.seed)\n\n    # Set device and enable bf16\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True\n    torch.set_float32_matmul_precision('high') \n\n    ###############################\n    ## Main Experiment settings ##\n    ###############################\n\n    ## Set which model to train \n    model, tokenizer = llms.get_llm_tokenizer(args.model_name, device)\n    base_model, _ = llms.get_llm_tokenizer(args.model_name, device)\n\n    ## Set which data set \n    train_loader, test_loader = rldatasets.get_dataloaders(args.dataset_name)\n\n    ## Set which evaluation criteria to use \n    eval_class = evaluator.get_evaluator(args.evaluator)\n\n    ###############################\n\n\n    # Setup logging \n    os.makedirs(args.output_dir, exist_ok=True)\n    args_dict = vars(args)\n    args_path = os.path.join(args.output_dir, 'args.json')\n    with open(args_path, 'w') as f:\n        json.dump(args_dict, f, indent=4)\n    eval_log_dir = os.path.join(args.output_dir, 'eval_logs')\n    os.makedirs(eval_log_dir, exist_ok=True)\n    train_log_dir = os.path.join(args.output_dir, 'training_logs')\n    os.makedirs(train_log_dir, exist_ok=True)\n\n\n    # Setup optimizer for trainer agent with GRPO config settings\n    optimizer = torch.optim.AdamW(\n        model.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.weight_decay,\n        eps=1e-8\n    )\n\n    # Add linear warmup learning rate scheduler\n    warmup_steps = int(args.warmup_percent * args.num_train_iters)\n    def get_lr(step):\n        if step < warmup_steps:\n            return (step / warmup_steps)\n        return 1.0\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=get_lr)\n\n\n    # Begin training \n    accumulated_loss = 0\n    optimizer.zero_grad()\n    train_metrics_total = {}\n    for round_num in tqdm(range(args.num_train_iters), desc=\"Training Progress\"):\n    \n        # Evaluate on test set every so often \n        if round_num % args.eval_iterations == 0:\n            eval_metrics, eval_accuracy = eval_on_test_set(\n                model=model,\n                tokenizer=tokenizer, \n                test_loader=test_loader,\n                eval_class=eval_class,\n                device=device,\n                args=args,\n                round_num=round_num\n            )\n            \n            # Save metrics to eval log dir\n            metrics_path = os.path.join(eval_log_dir, f'metrics_{round_num}.json')\n            with open(metrics_path, 'w') as f:\n                json.dump({\n                    'metrics': eval_metrics,\n                    'accuracy': eval_accuracy\n                }, f, indent=4)\n\n        # Slowly update ref model\n        if args.update_ref_model and (round_num+1) % args.update_ref_model_freq == 0:\n            with torch.no_grad():\n                for param, ref_param in zip(model.parameters(), base_model.parameters()):\n                    ref_param.data = args.ref_model_mixup_alpha * param.data + (1 - args.ref_model_mixup_alpha) * ref_param.data\n\n        # Get next question\n        question, answer = next(train_loader)\n\n        # Do GRPO - generate chains, score, compute advantage, compute loss \n        total_loss, train_metrics = grpo_loss(model, base_model, tokenizer, 
question, answer, eval_class, device, round_num, train_log_dir, args)\n        \n        # Gradient accumulation: scale the loss so gradients average over the accumulated steps\n        total_loss = total_loss / args.gradient_accumulation_steps\n        total_loss.backward()\n        accumulated_loss += total_loss.item()\n\n        # Log the pre-step gradient norm (max_norm=inf computes the norm without clipping)\n        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item()\n\n        # Step optimizer every gradient_accumulation_steps iterations\n        if (round_num + 1) % args.gradient_accumulation_steps == 0:\n            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n            optimizer.step()\n            optimizer.zero_grad()\n        scheduler.step()\n\n        # Logs\n        train_metrics[\"learning_rate\"] = scheduler.get_last_lr()[0]\n        # Undo the accumulation scaling so the logged loss is per-iteration\n        train_metrics[\"loss\"] = total_loss.item() * args.gradient_accumulation_steps\n        train_metrics[\"grad_norm\"] = grad_norm\n        train_metrics_total[round_num] = train_metrics\n        with open(os.path.join(train_log_dir, \"train_logs.json\"), \"w\") as f:\n            json.dump(train_metrics_total, f, indent=4)\n
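\n# Example usage (mirroring run.sh; assumed single-GPU):\n#   python main.py --output_dir final1 --verbose\n#   python plotter.py --log_dir final1\n"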
  },
  {
    "path": "plotter.py",
    "content": "import os\nimport json\nimport argparse\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.style as style\nfrom matplotlib.backends.backend_pdf import PdfPages\n\ndef moving_average(data, window_size=5):\n    \"\"\"Calculate moving average with given window size\"\"\"\n    weights = np.ones(window_size) / window_size\n    return np.convolve(data, weights, mode='valid')\n\ndef plot_metrics(output_dir):\n    \"\"\"\n    Plot training metrics from training_logs directory.\n    Creates PDF with separate plots for each metric over training steps.\n    Uses a modern, professional style with custom color palette.\n    \"\"\"\n    # Load training logs\n    train_logs_path = os.path.join(output_dir, 'training_logs', 'train_logs.json')\n    with open(train_logs_path, 'r') as f:\n        train_logs = json.load(f)\n\n    # Load evaluation logs\n    eval_logs = {}\n    eval_logs_dir = os.path.join(output_dir, 'eval_logs')\n    for filename in os.listdir(eval_logs_dir):\n        if filename.startswith('metrics_') and filename.endswith('.json'):\n            step = int(filename.split('_')[1].split('.')[0])\n            with open(os.path.join(eval_logs_dir, filename), 'r') as f:\n                eval_logs[step] = json.load(f)\n\n    # Set style and color palette\n    plt.style.use('bmh')  # Using 'bmh' style which is a modern, clean style\n    colors = ['#2ecc71', '#e74c3c', '#3498db', '#f1c40f', '#9b59b6', '#1abc9c', '#e67e22', '#34495e']\n    \n    # Create PDF to save all plots\n    pdf_path = os.path.join(output_dir, 'training_plots.pdf')\n    with PdfPages(pdf_path) as pdf:\n        \n        # Plot reward metrics\n        reward_metrics = [\n            'rewards/correctness_reward_func',\n            'rewards/int_reward_func', \n            'rewards/strict_format_reward_func',\n            'rewards/soft_format_reward_func',\n            'rewards/xmlcount_reward_func',\n            'reward'\n        ]\n        \n        for metric, color in zip(reward_metrics, colors):\n            plt.figure(figsize=(12,7))\n            steps = [int(x) for x in train_logs.keys()]\n            values = [metrics[metric] for metrics in train_logs.values()]\n            \n            # Plot raw data with low alpha\n            plt.plot(steps, values, color=color, alpha=0.3, linewidth=1.5, label='Raw data')\n            \n            # Calculate and plot moving average if we have enough data points\n            if len(values) > 5:\n                ma_values = moving_average(values)\n                ma_steps = steps[len(steps)-len(ma_values):]\n                plt.plot(ma_steps, ma_values, color=color, linewidth=2.5, label='Moving average')\n            \n            plt.xlabel('Training Steps', fontsize=12)\n            plt.ylabel(f'{metric.split(\"/\")[-1].replace(\"_\", \" \").title()}', fontsize=12)\n            plt.title(f'{metric.split(\"/\")[-1].replace(\"_\", \" \").title()}', fontsize=14, pad=20)\n            plt.grid(True, alpha=0.3)\n            plt.legend()\n            pdf.savefig(bbox_inches='tight')\n            plt.close()\n\n        # Plot learning rate\n        plt.figure(figsize=(12,7))\n        steps = [int(x) for x in train_logs.keys()]\n        lr_values = [metrics['learning_rate'] for metrics in train_logs.values()]\n\n        plt.plot(steps, lr_values, color='#e74c3c', linewidth=2.0, label='Learning Rate')\n        \n        plt.xlabel('Training Steps', fontsize=12)\n        plt.ylabel('Learning Rate', fontsize=12)\n        plt.title('Learning Rate Schedule', 
fontsize=14, pad=20)\n        plt.grid(True, alpha=0.3)\n        plt.legend()\n        pdf.savefig(bbox_inches='tight')\n        plt.close()\n\n        # Plot reward standard deviation\n        plt.figure(figsize=(12,7))\n        reward_std = [metrics['reward_std'] for metrics in train_logs.values()]\n\n        plt.plot(steps, reward_std, color='#3498db', alpha=0.3, linewidth=1.5, label='Reward Std (Raw)')\n        if len(reward_std) > 5:\n            ma_std = moving_average(reward_std)\n            ma_steps = steps[len(steps)-len(ma_std):]\n            plt.plot(ma_steps, ma_std, color='#3498db', linewidth=2.5, label='Reward Std (MA)')\n\n        plt.xlabel('Training Steps', fontsize=12)\n        plt.ylabel('Standard Deviation', fontsize=12)\n        plt.title('Reward Standard Deviation', fontsize=14, pad=20)\n        plt.grid(True, alpha=0.3)\n        plt.legend()\n        pdf.savefig(bbox_inches='tight')\n        plt.close()\n\n        # Plot loss\n        plt.figure(figsize=(12,7))\n        loss_values = [metrics['loss'] for metrics in train_logs.values()]\n\n        plt.plot(steps, loss_values, color='#e67e22', alpha=0.3, linewidth=1.5, label='Loss (Raw)')\n        if len(loss_values) > 5:\n            ma_loss = moving_average(loss_values)\n            ma_steps = steps[len(steps)-len(ma_loss):]\n            plt.plot(ma_steps, ma_loss, color='#e67e22', linewidth=2.5, label='Loss (MA)')\n\n        plt.xlabel('Training Steps', fontsize=12)\n        plt.ylabel('Loss', fontsize=12)\n        plt.title('Training Loss', fontsize=14, pad=20)\n        plt.grid(True, alpha=0.3)\n        plt.legend()\n        pdf.savefig(bbox_inches='tight')\n        plt.close()\n\n        # Plot KL divergence\n        plt.figure(figsize=(12,7))\n        kl_values = [metrics['kl'] for metrics in train_logs.values()]\n\n        plt.plot(steps, kl_values, color='#9b59b6', alpha=0.3, linewidth=1.5, label='KL Divergence (Raw)')\n        if len(kl_values) > 5:\n            ma_kl = moving_average(kl_values)\n            ma_steps = steps[len(steps)-len(ma_kl):]\n            plt.plot(ma_steps, ma_kl, color='#9b59b6', linewidth=2.5, label='KL Divergence (MA)')\n\n        plt.xlabel('Training Steps', fontsize=12)\n        plt.ylabel('KL Divergence', fontsize=12)\n        plt.title('KL Divergence', fontsize=14, pad=20)\n        plt.grid(True, alpha=0.3)\n        plt.legend()\n        pdf.savefig(bbox_inches='tight')\n        plt.close()\n\n        # Plot evaluation metrics\n        if eval_logs:\n            eval_steps = sorted(eval_logs.keys())\n            \n            # Plot accuracy\n            plt.figure(figsize=(12,7))\n            accuracy_values = [eval_logs[step]['accuracy'] for step in eval_steps]\n            plt.plot(eval_steps, accuracy_values, color='#2ecc71', linewidth=2.0, label='Accuracy')\n            plt.xlabel('Training Steps', fontsize=12)\n            plt.ylabel('Accuracy (%)', fontsize=12)\n            plt.title('Evaluation Accuracy', fontsize=14, pad=20)\n            plt.grid(True, alpha=0.3)\n            plt.legend()\n            pdf.savefig(bbox_inches='tight')\n            plt.close()\n\n            # Plot evaluation reward metrics\n            eval_metrics = [key for key in eval_logs[eval_steps[0]]['metrics'].keys()]\n            for metric, color in zip(eval_metrics, colors):\n                plt.figure(figsize=(12,7))\n                metric_values = [eval_logs[step]['metrics'][metric] for step in eval_steps]\n                plt.plot(eval_steps, metric_values, color=color, linewidth=2.0, 
label=metric)\n                plt.xlabel('Training Steps', fontsize=12)\n                plt.ylabel(metric.replace('_', ' ').title(), fontsize=12)\n                plt.title(f'Evaluation {metric.replace(\"_\", \" \").title()}', fontsize=14, pad=20)\n                plt.grid(True, alpha=0.3)\n                plt.legend()\n                pdf.savefig(bbox_inches='tight')\n                plt.close()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='Plot training metrics from logs directory')\n    parser.add_argument('--log_dir', type=str, required=True, help='Directory containing training logs')\n    args = parser.parse_args()\n    plot_metrics(args.log_dir)\n
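\n# Example: python plotter.py --log_dir final1   (writes final1/training_plots.pdf)\n"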
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==2.1.0\naccelerate==1.3.0\naiohappyeyeballs==2.4.4\naiohttp==3.11.11\naiosignal==1.3.2\nannotated-types==0.7.0\nanyio==4.8.0\nappdirs==1.4.4\nargcomplete==1.8.1\nastunparse==1.6.3\nasync-timeout==5.0.1\nattrs==21.2.0\nAutomat==20.2.0\nBabel==2.8.0\nbackcall==0.2.0\nbcrypt==3.2.0\nbeautifulsoup4==4.10.0\nbeniget==0.4.1\nbleach==4.1.0\nblinker==1.4\nbottle==0.12.19\nBrotli==1.0.9\ncertifi==2020.6.20\ncffi==1.15.0\nchardet==4.0.0\ncharset-normalizer==3.4.1\nclick==8.0.3\ncloud-init==24.4\ncolorama==0.4.4\ncommand-not-found==0.3\ncommonmark==0.9.1\nconfigobj==5.0.6\nconstantly==15.1.0\ncryptography==3.4.8\nctop==1.0.0\ncycler==0.11.0\ndatasets==3.2.0\ndbus-python==1.2.18\ndecorator==4.4.2\ndefusedxml==0.7.1\ndill==0.3.8\ndistlib==0.3.4\ndistro==1.7.0\ndistro-info==1.1+ubuntu0.2\ndocker==5.0.3\ndocker-pycreds==0.4.0\neinops==0.8.0\nentrypoints==0.4\nexceptiongroup==1.2.2\nfilelock==3.6.0\nflake8==4.0.1\nflash_attn==2.7.4.post1\nflatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6\nfonttools==4.29.1\nfpdf==1.7.2\nfrozenlist==1.5.0\nfs==2.4.12\nfsspec==2024.3.1\nfuture==0.18.2\ngast==0.5.2\ngitdb==4.0.12\nGitPython==3.1.44\nGlances==3.2.4.2\ngoogle-pasta==0.2.0\ngrpcio==1.30.2\nh11==0.14.0\nh5py==3.6.0\nh5py.-debian-h5py-serial==3.6.0\nhtml5lib==1.1\nhttpcore==1.0.7\nhttplib2==0.20.2\nhttpx==0.28.1\nhuggingface-hub==0.28.1\nhyperlink==21.0.0\nicdiff==2.0.4\nidna==3.3\nimportlib-metadata==4.6.4\nincremental==21.3.0\ninfluxdb==5.3.1\niotop==0.6\nipykernel==6.7.0\nipython==7.31.1\nipython_genutils==0.2.0\njax==0.4.30\njaxlib==0.4.30\njedi==0.18.0\njeepney==0.7.1\nJinja2==3.1.5\njiter==0.8.2\njoblib==0.17.0\njsonpatch==1.32\njsonpointer==2.0\njsonschema==3.2.0\njupyter-client==7.1.2\njupyter-core==4.9.1\nkaptan==0.5.12\nkeras==3.6.0\nkeyring==23.5.0\nkiwisolver==1.3.2\nlaunchpadlib==1.10.16\nlazr.restfulclient==0.14.4\nlazr.uri==1.0.6\nlibtmux==0.10.1\nlivereload==2.6.3\nlxml==4.8.0\nlz4==3.1.3+dfsg\nMarkdown==3.3.6\nMarkupSafe==2.0.1\nmatplotlib==3.5.1\nmatplotlib-inline==0.1.3\nmccabe==0.6.1\nmkdocs==1.1.2\nml-dtypes==0.5.0\nmore-itertools==8.10.0\nmpmath==0.0.0\nmsgpack==1.0.3\nmultidict==6.1.0\nmultiprocess==0.70.16\nnamex==0.0.8\nnest-asyncio==1.5.4\nnetifaces==0.11.0\nnetworkx==2.4\nnumpy==1.21.5\nnvidia-ml-py==12.555.43\noauthlib==3.2.0\nolefile==0.46\nopenai==1.61.0\nopt-einsum==3.3.0\noptree==0.13.1\npackaging==21.3\npandas==1.3.5\nparso==0.8.1\npeft==0.14.0\npexpect==4.8.0\npickleshare==0.7.5\nPillow==9.0.1\npipx==1.0.0\nplatformdirs==2.5.1\nply==3.11\nprompt-toolkit==3.0.28\npropcache==0.2.1\nprotobuf==4.21.12\npsutil==5.9.0\nptyprocess==0.7.0\npy==1.10.0\npyarrow==19.0.0\npyasn1==0.4.8\npyasn1-modules==0.2.1\npycodestyle==2.8.0\npycparser==2.21\npycryptodomex==3.11.0\npydantic==2.10.6\npydantic_core==2.27.2\npyflakes==2.4.0\nPygments==2.11.2\nPyGObject==3.42.1\nPyHamcrest==2.0.2\npyinotify==0.9.6\nPyJWT==2.3.0\npyOpenSSL==21.0.0\npyparsing==2.4.7\npyrsistent==0.18.1\npyserial==3.5\npysmi==0.3.2\npysnmp==4.4.12\npystache==0.6.0\npython-apt==2.4.0+ubuntu4\npython-dateutil==2.8.1\npython-magic==0.4.24\npythran==0.10.0\npytz==2022.1\nPyYAML==5.4.1\npyzmq==22.3.0\nregex==2024.11.6\nrequests==2.32.3\nrich==11.2.0\nsafetensors==0.5.2\nscikit-learn==0.23.2\nscipy==1.8.0\nSecretStorage==3.3.1\nsentry-sdk==2.20.0\nservice-identity==18.1.0\nsetproctitle==1.3.4\nsix==1.16.0\nsmmap==5.0.2\nsniffio==1.3.1\nsos==4.7.2\nsoupsieve==2.3.1\nssh-import-id==5.11\nsympy==1.12\ntensorboard==2.18.0\ntensorflow==2.18.0\ntermcolor==1.1.0\ntf_keras==2.18.0\nthreadpoolctl==3.1.0\ntmuxp==1.9.2\nt
okenizers==0.21.0\ntorch==2.5.1\ntorchvision==0.20.1\ntornado==6.1\ntqdm==4.67.1\ntraitlets==5.1.1\ntransformers==4.48.2\ntriton==3.1.0\ntrl==0.14.0\nTwisted==22.1.0\ntyping_extensions==4.12.2\nufoLib2==0.13.1\nufw==0.36.1\nunattended-upgrades==0.1\nunicodedata2==14.0.0\nurllib3==2.3.0\nuserpath==1.8.0\nvirtualenv==20.13.0+ds\nwadllib==1.3.6\nwandb==0.19.5\nwcwidth==0.2.5\nwebencodings==0.5.1\nwebsocket-client==1.2.3\nWerkzeug==2.0.2\nwrapt==1.13.3\nxxhash==3.5.0\nyarl==1.18.3\nzipp==1.0.0\nzope.interface==5.4.0\n"
  },
  {
    "path": "rldatasets.py",
    "content": "\"\"\"\nHold all data sets \n\n\"\"\"\n\nimport random\nimport numpy as np\nfrom tqdm import tqdm\nfrom datasets import load_dataset, Dataset\nfrom abc import ABC, abstractmethod\nfrom typing import Tuple, Any\n\n\n\nclass DataLoader(ABC):\n    \"\"\"\n    Abstract base class for data loaders.\n    \n    This class defines the interface that all dataset loaders should implement.\n    Specific dataset loaders should inherit from this class and implement the\n    required methods.\n    \n    Attributes:\n        random (bool): If True, returns items randomly; if False, returns sequentially\n        current_index (int): Current position for sequential access\n    \"\"\"\n    \n    def __init__(self, random: bool = False) -> None:\n        self.random = random\n        self.current_index = 0\n        \n    @abstractmethod\n    def __len__(self) -> int:\n        \"\"\"Return the total number of items in the dataset.\"\"\"\n        pass\n        \n    @abstractmethod\n    def __iter__(self) -> 'DataLoader':\n        \"\"\"Return self as iterator.\"\"\"\n        return self\n        \n    @abstractmethod\n    def __next__(self) -> Any:\n        \"\"\"Return the next item(s) in the dataset.\"\"\"\n        pass\n\n\ndef extract_hash_answer(text: str) -> str | None:\n    if \"####\" not in text:\n        return None\n    return text.split(\"####\")[1].strip()\n\n\n\nSYSTEM_PROMPT = \"\"\"\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n\"\"\"\n\n\n\nclass GSM8KLoader(DataLoader):\n    \"\"\"\n    A loader class that provides iteration over GSM8K math problems.\n    \n    This class implements both sequential and random access to math problems through\n    standard Python iterator protocols. It can be used to iterate over problems either\n    in order or randomly, making it suitable for both training and evaluation.\n    \n    Attributes:\n        questions (List[str]): List of math question strings\n        answers (List[str]): List of corresponding answer strings\n        random (bool): If True, returns problems randomly; if False, returns sequentially\n        current_index (int): Current position in the lists for sequential access\n    \"\"\"\n    \n    def __init__(self, questions: list[str], answers: list[str], random: bool = False) -> None:\n        super().__init__(random)\n        self.questions = questions\n        self.answers = answers\n        self.pre_prompt = \"\"\"You will be given a question that involves reasoning. You should reason carefully about the question, then provide your answer.\n            It is very important that you put your reasoning process inside <reasoning> tags and your final answer inside <answer> tags, like this:\n\n            \n            <reasoning>\n            Your step-by-step reasoning process here\n            </reasoning>\n            <answer>\n            Your final answer here\n            </answer>\n\n            All of your returned text should either be in the <reasoning> or <answer> tags - no text outside! Start each answer by immediately starting with <reasoning>. 
\n            It is extremely important you answer in this way - do not put any information or text outside of these tags!\n\n            Question: \"\"\"\n        self.system_prompt = SYSTEM_PROMPT\n        \n    def __len__(self) -> int:\n        return len(self.questions)\n        \n    def __iter__(self) -> 'GSM8KLoader':\n        return self\n        \n    def __next__(self) -> tuple[str, str]:\n        if self.current_index >= len(self.questions):\n            raise StopIteration\n        \n        if self.random:\n            idx = random.randint(0, len(self.questions) - 1)\n        else:\n            idx = self.current_index\n            self.current_index += 1\n            \n        return self.questions[idx], self.answers[idx]\n\n    def reset(self):\n        self.current_index = 0\n\n\ndef build_gsm8k_dataloaders() -> Tuple[GSM8KLoader, GSM8KLoader]: \n    data = load_dataset('openai/gsm8k', 'main')[\"train\"]\n\n    questions = []\n    parsed_answers = [] \n    for i in tqdm(range(len(data)), desc=\"Processing\"):\n        # Try to extract the answer - if it is None, skip this sample\n        ans = extract_hash_answer(data[i]['answer'])\n        if ans is None: \n            continue \n        else:\n            questions.append(data[i]['question'])\n            parsed_answers.append(ans)\n\n    # Randomly split into train/test sets\n    total_samples = len(questions)\n    test_size = int(total_samples * 0.01)  # hold out 1% for the test set\n    \n    # Generate random indices for test set\n    test_indices = random.sample(range(total_samples), test_size)\n    test_indices_set = set(test_indices)\n    \n    # Convert to numpy arrays for easier indexing\n    questions = np.array(questions)\n    parsed_answers = np.array(parsed_answers)\n    \n    # Create boolean mask for test indices\n    test_mask = np.zeros(total_samples, dtype=bool)\n    test_mask[list(test_indices_set)] = True\n    \n    # Split using boolean indexing\n    test_questions = questions[test_mask]\n    test_answers = parsed_answers[test_mask]\n    train_questions = questions[~test_mask] \n    train_answers = parsed_answers[~test_mask]\n\n    # Setup data loaders \n    trainloader = GSM8KLoader(train_questions.tolist(), train_answers.tolist())\n    testloader = GSM8KLoader(test_questions.tolist(), test_answers.tolist())\n    \n    return trainloader, testloader\n\n\ndef get_dataloaders(dataset_name: str) -> Tuple[DataLoader, DataLoader]:\n    \"\"\"\n    Factory function to get train and test data loaders for a specified dataset.\n    \n    Args:\n        dataset_name (str): Name of the dataset to load ('gsm8k' currently supported)\n        \n    Returns:\n        Tuple[DataLoader, DataLoader]: Train and test data loaders\n        \n    Raises:\n        ValueError: If dataset_name is not supported\n    \"\"\"\n    if dataset_name.lower() == 'gsm8k':\n        return build_gsm8k_dataloaders()\n    else:\n        raise ValueError(f\"Dataset {dataset_name} not supported. Currently only 'gsm8k' is available.\")\n\n\nif __name__ == \"__main__\": \n    trainloader, testloader = get_dataloaders('gsm8k')\n
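\n    # Quick sanity check (illustrative): each loader yields one (question, answer)\n    # string pair per step, just as main.py's training loop consumes them.\n    q, a = next(trainloader)\n    print(f\"Q: {q[:80]}...\")\n    print(f\"A: {a}\")\n"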
  },
  {
    "path": "run.sh",
    "content": "# python main.py --output_dir \"final1\" --verbose\npython plotter.py --log_dir \"final1\""
  },
  {
    "path": "utils.py",
    "content": "import os\nimport torch\nimport random\nimport numpy as np\nimport torch.nn.functional as F\nfrom typing import Any, Dict, Optional\n\nimport re\n\n####################\n## MISC FUNCTIONS ##\n####################\n\ndef clean_spaces_preserve_newlines(text):\n    # Replace multiple spaces with a single space, but preserve newlines\n    lines = text.split(\"\\n\")  # Split by newlines\n    cleaned_lines = [\" \".join(re.split(r\"\\s+\", line)).strip() for line in lines]  # Remove extra spaces in each line\n    return \"\\n\".join(cleaned_lines)  # Join the lines back with newlines\n\n\n\ndef seed_everything(seed: int) -> None:\n    \"\"\"\n    Set random seed for reproducibility across multiple libraries.\n    \n    This function sets consistent random seeds for Python's random module,\n    NumPy, PyTorch (both CPU and CUDA), and configures CUDNN for deterministic\n    operation. This ensures reproducible results across multiple runs.\n\n    Args:\n        seed: The random seed to use for all random number generators\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    \n    # Additional settings for reproducibility\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\n\ndef write_generation_log(log_data: Dict[str, Any], log_file: str) -> None:\n    \"\"\"\n    Write generation log data to a text file.\n\n    Args:\n        log_data: Dictionary containing prompt and generation data\n        log_file: Path to output log file\n    \"\"\"\n    with open(log_file, 'w') as f:\n        # Write prompt section\n        f.write(\"###### ORIGINAL PROMPT #####\\n\\n\")\n        f.write(log_data['prompt']['text'] + \"\\n\\n\")\n        f.write(\"#### ANS ####\\n\\n\")\n        f.write(str(log_data['prompt']['answer']) + \"\\n\")\n\n        # Write each generation\n        for i, gen in enumerate(log_data['generations'], 1):\n            f.write(f\"#### GENERATION {i} RESPONSE ####\\n\\n\")\n            f.write(gen['response'] + \"\\n\\n\")\n            f.write(f\"#### GENERATION {i} SCORES ####\\n\")\n            \n            # Write individual scores\n            f.write(f\"Correctness: {gen['scores']['correctness']}\\n\")\n            f.write(f\"Integer format: {gen['scores']['integer_format']}\\n\") \n            f.write(f\"Strict format: {gen['scores']['strict_format']}\\n\")\n            f.write(f\"Soft format: {gen['scores']['soft_format']}\\n\")\n            f.write(f\"XML count: {gen['scores']['xml_count']}\\n\")\n            f.write(f\"Total reward: {gen['scores']['total_reward']}\\n\\n\")\n\n\n####################################################################################\n## Copied Directly from TRL -> generate log probs per token                 ########\n## https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py ########\n####################################################################################\n\ndef selective_log_softmax(logits, index):\n    \"\"\"\n    A memory-efficient implementation of the common `log_softmax -> gather` operation.\n\n    This function is equivalent to the following naive implementation:\n    ```python\n    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)\n    ```\n\n    Args:\n        logits (`torch.Tensor`):\n            Logits tensor of shape `(..., num_classes)`.\n        index (`torch.Tensor`):\n            Index tensor of shape `(...)`, specifying the 
positions to gather from the log-softmax output.\n\n    Returns:\n        `torch.Tensor`:\n            Gathered log probabilities with the same shape as `index`.\n    \"\"\"\n    if logits.dtype in [torch.float32, torch.float64]:\n        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)\n        # loop to reduce peak mem consumption\n        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])\n        per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n    else:\n        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n        per_token_logps = []\n        for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption\n            row_logps = F.log_softmax(row_logits, dim=-1)\n            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n            per_token_logps.append(row_per_token_logps)\n        per_token_logps = torch.stack(per_token_logps)\n    return per_token_logps\n\ndef get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):\n    # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded\n    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits\n    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred\n\n    input_ids = input_ids[:, -logits_to_keep:]\n    # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.\n    # See https://github.com/huggingface/trl/issues/2770\n    logits = logits[:, -logits_to_keep:]\n    return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens\n
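\n# Shape sketch (for orientation): given input_ids of shape (B, P+C) where\n# C = logits_to_keep completion tokens, get_per_token_logps returns a (B, C)\n# tensor of log-probabilities, one per generated token.\n"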
  }
]