[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n"
  },
  {
    "path": "README.md",
    "content": "# Prompt-to-Leaderboard (P2L)\n\nThis is the codebase for the paper [Prompt-to-Leaderboard](https://arxiv.org/pdf/2502.14855).\n\nModels weights found at our [LMArena HF Collection](https://huggingface.co/collections/lmarena-ai/prompt-to-leaderboard-67bcf7ddf6022ef3cfd260cc).\n\nTry on Chatbot Arena at the [Prompt-to-Leaderboard](https://lmarena.ai/?p2l) tab!\n\n## Abstract\nLarge language model (LLM) evaluations typically rely on aggregated metrics like accuracy or human preference, averaging across users and prompts. This averaging obscures user- and prompt-specific variations in model performance.\nTo address this, we propose Prompt-to-Leaderboard (P2L), a method that produces leaderboards specific to a prompt or set of prompts.\nThe core idea is to train an LLM taking natural language prompts as input to output a vector of Bradley-Terry coefficients which are then used to predict the human preference vote.\nThe resulting prompt-dependent leaderboards allow for unsupervised task-specific evaluation, optimal routing of queries to models, personalization, and automated evaluation of model strengths and weaknesses. \nData from Chatbot Arena suggest that P2L better captures the nuanced landscape of language model performance than the averaged leaderboard. \nFurthermore, our findings suggest that P2L's ability to produce prompt-specific evaluations follows a power law scaling similar to that observed in LLMs themselves. In January 2025, the router we trained based on this methodology achieved the #1 spot in the Chatbot Arena leaderboard.\n\n## Table of Contents\n\n- [P2L](#p2l)\n  - [Abstract](#abstract)\n  - [Table of Contents](#table-of-contents)\n  - [Environment Setup](#environment-setup)\n    - [Installing `uv`](#installing-uv)\n    - [Serving P2L Setup](#serving-p2l-setup)\n    - [Serving a Router Setup](#serving-a-router-setup)\n    - [Training Setup](#training-setup)\n  - [Serving P2L](#serving-p2l)\n  - [Serving an OpenAI Compatible Router](#serving-an-openai-compatible-router)\n    - [Example: serving a Bradley-Terry based cost-optimal router](#example-serving-a-bradley-terry-based-cost-optimal-router)\n    - [Example: serving a Grounded RK based simple cost router](#example-serving-a-grounded-rk-based-simple-cost-router)\n  - [Calling the OpenAI Compatible Router](#calling-the-openai-compatible-router)\n  - [Training a P2L Model](#training-a-p2l-model)\n  - [Inferencing a P2L Model](#inferencing-a-p2l-model)\n  - [AutoEval Suite](#autoeval-suite)\n    - [Params](#params)\n  - [Citation](#citation)\n\n\n## Environment Setup\n\nSetup instuctions will be shown using `uv`, however any package management system will work. 
All environments are native to Python 3.10; other versions are untested but may also work.\n\n### Installing `uv`\n\nIf you like the sound of ~50x faster environment setup times, run the following to install `uv`:\n\n```bash\ncurl -LsSf https://astral.sh/uv/install.sh | sh\n\nsource $HOME/.local/bin/env\n```\n\nTo create a Python virtual environment, run:\n\n```bash\nuv venv .env --python 3.10\n```\n\nTo activate said environment, run:\n\n```bash\nsource .env/bin/activate\n```\n\n### Serving P2L Setup\n\nTo serve a P2L model, first run:\n\n```bash\nuv pip install -r serve_requirements.txt\n```\n\n### Serving a Router Setup\n\nTo serve an OpenAI compatible router, first run:\n\n```bash\nuv pip install -r route/requirements.txt\n```\n\n### Training Setup\n\nTo train a P2L model, first run:\n\n```bash\nuv pip install -r train_requirements.txt\n```\n\n## Serving P2L\n\nBefore getting started, make sure you have followed the steps in [Serving P2L Setup](#serving-p2l-setup).\n\n`python -m p2l.endpoint` accepts the following arguments:\n\n| Option | Short Flag | Description |\n|--------|-----------|-------------|\n| `--help` | `-h` | Show this help message and exit. |\n| `--model-path MODEL_PATH` | `-m MODEL_PATH` | Path to the model repository. |\n| `--model-type MODEL_TYPE` | `-mt MODEL_TYPE` | Type of the model. |\n| `--head-type HEAD_TYPE` | `-ht HEAD_TYPE` | Type of model head. |\n| `--loss-type LOSS_TYPE` | `-lt LOSS_TYPE` | Type of the loss function. |\n| `--api-key API_KEY` | `-a API_KEY` | API key for authorization. |\n| `--host HOST` | `-H HOST` | Host to run the server on. |\n| `--port PORT` | `-p PORT` | Port to run the server on. |\n| `--reload, --no-reload` | - | Whether to reload the endpoint on detected code changes (requires workers to be set to 1). |\n| `--workers WORKERS` | - | Number of endpoint workers (each will hold a model instance). |\n| `--cuda, --no-cuda` | - | Flag to enable hosting the model on a GPU. Enabled by default. |\n\nFor example, to run lmarena-ai/p2l-7b-grk-02222025, a Qwen2-based \"grk\" model with head type `rk`, we would run:\n\n```bash\npython -m p2l.endpoint --model-path lmarena-ai/p2l-7b-grk-02222025 --model-type qwen2 --head-type rk --api-key <your-desired-api-key>\n```\n\nBy default, this hosts the model with 1 worker on host 0.0.0.0 and port 10250, with reload enabled, meaning code changes will restart the endpoint. Note that the endpoint expects to load the model onto a GPU by default; by specifying `--no-cuda` you can run on CPU only, which may work for smaller P2L models.\n\nEach P2L model has an associated model list, which specifies which model each index of the output coefficient vector corresponds to. Below is an example function to get this model list from the hosted endpoint:\n\n```python\nimport json\nfrom typing import Dict, List\n\nimport requests\n\n\ndef get_p2l_endpoint_models(base_url: str, api_key: str) -> List[str]:\n\n    headers = {\n        \"Content-Type\": \"application/json\",\n        \"api-key\": api_key,\n    }\n\n    try:\n        response = requests.get(f\"{base_url}/models\", headers=headers)\n        response.raise_for_status()\n        result = response.json()\n        return result[\"models\"]\n\n    except Exception as err:\n        print(f\"An error occurred: {err}\")\n```\n\nBelow is an example Python function to query the P2L endpoint:\n\n```python\ndef query_p2l_endpoint(\n    prompt: List[str], base_url: str, api_key: str\n) -> Dict[str, List]:\n\n    headers = {\n        \"Content-Type\": \"application/json\",\n        \"api-key\": api_key,\n    }\n\n    payload = {\"prompt\": prompt}\n\n    response = requests.post(\n        f\"{base_url}/predict\", headers=headers, data=json.dumps(payload)\n    )\n    response.raise_for_status()\n    return response.json()\n```\n\nNote that the input is a list of strings. This is NOT a batch of prompts, but rather one string per user turn in a conversation. For example, given a 2-turn conversation:\n\n```\nUser: \"hi!\"\nAssistant: \"Hello\"\nUser: \"what's 1+1?\"\n```\n\nThe correct P2L input would be:\n\n```python\n[\"hi!\", \"what's 1+1?\"]\n```\n\n
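Putting the two functions together, here is a minimal sketch that ranks models for a single prompt. Note that the key holding the coefficient vector in the `/predict` response (`\"coefs\"` below) is an assumption for illustration; check your endpoint's actual response schema before relying on it.\n\n```python\n# Hypothetical usage sketch: rank models for one prompt.\n# Assumes the /predict response stores the coefficient vector under \"coefs\".\nbase_url = \"http://0.0.0.0:10250\"\napi_key = \"<your-api-key>\"\n\nmodels = get_p2l_endpoint_models(base_url, api_key)\nresult = query_p2l_endpoint([\"what's 1+1?\"], base_url, api_key)\ncoefs = result[\"coefs\"]\n\n# A higher Bradley-Terry coefficient means the model is preferred for this prompt.\nleaderboard = sorted(zip(models, coefs), key=lambda pair: pair[1], reverse=True)\nfor model, coef in leaderboard:\n    print(f\"{model}: {coef:.3f}\")\n```\n\n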
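As a worked illustration of that formula (the token price and response length below are hypothetical, not measured values), a cost entry could be derived like this:\n\n```python\n# Hypothetical cost derivation for a config entry.\ncost_per_output_token = 15e-6  # e.g. $15 per 1M output tokens (made-up figure)\naverage_output_tokens_per_response = 620  # made-up figure\n\ncost = cost_per_output_token * average_output_tokens_per_response\nprint(cost)  # ~0.0093; only relative magnitudes matter, so any common scale factor is fine\n```\n\n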
Now, let's assume we put the above config content into `config.yaml`. To start the OpenAI compatible router, we would run:\n\n```bash\npython -m route.openai_server --config config.yaml --router-type bt-endpoint --router-model-endpoint http://0.0.0.0:10250 --router-api-key <your-api-key> --cost-optimizer optimal-lp --api-key <your-endpoint-api-key>\n```\n\nLet's break down what this command means:\n\n- `--router-type bt-endpoint`: we are using a Bradley-Terry based P2L model hosted on an endpoint.\n- `--router-model-endpoint http://0.0.0.0:10250`: this is where the P2L model endpoint is served; this default address is correct if the routing server runs on the same machine as the P2L endpoint.\n- `--cost-optimizer optimal-lp`: we enable cost routing using the optimal linear program detailed in Theorem 1 of the paper.\n\n>**Note**: `optimal-lp` is only compatible with BT models, and `simple-lp` is only compatible with grounded RK (sometimes specified as bag) models.\n\n\n### Example: serving a Grounded RK based simple cost router\n\nP2L has a class of \"Grounded RK\" models. These models produce coefficients such that `0.0` represents the threshold for a \"usable\" answer. We can leverage this to cost route to maximize $P(\\text{Not Bad})$... whatever that means exactly. Below we detail the steps to run this routing setup.\n\nFirst, start up the P2L endpoint:\n\n```bash\npython -m p2l.endpoint --model-path lmarena-ai/p2l-7b-grk-02222025 --model-type qwen2 --head-type rk --api-key <your-desired-api-key>\n```\n\nThen start up the router server:\n\n```bash\npython -m route.openai_server --config config.yaml --router-type grk-endpoint --router-model-endpoint http://0.0.0.0:10250 --router-api-key <your-api-key> --cost-optimizer simple-lp --api-key <your-endpoint-api-key>\n```\n\n## Calling the OpenAI Compatible Router\n\nAs aptly named, the router server is OpenAI compatible. We can call it like any other OpenAI compatible model:\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(\n    base_url=\"<your_router_endpoint_url>/v1\",\n    api_key=\"<your_router_api_key>\",\n)\n\nprompt = \"what's 828913*1234?\"\n\nresponse = client.chat.completions.create(\n    model=\"-\",  # This field is actually not used\n    messages=[{\"role\": \"user\", \"content\": prompt}],\n    stream=True,  # Router is compatible with and without streaming.\n)\n# Notice no temperature, top_p, or system prompt is set.\n# This allows the router to use the defaults provided by the config file.\n# If you do pass in these fields, they will override the config.\n```\n\nIf we want to specify a cost budget, we need to do the following:\n\n```python\nresponse = client.chat.completions.create(\n    model=\"-\",  # This field is actually not used\n    messages=[{\"role\": \"user\", \"content\": prompt}],\n    stream=True,  # Router is compatible with and without streaming.\n    extra_body={\"cost\": <desired_cost>}\n)\n```\n\n
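Since the examples above set `stream=True`, here is a short sketch of consuming the streamed response. This is standard OpenAI client usage; nothing router-specific is assumed:\n\n```python\n# Print the routed model's answer as it streams in.\nanswer_chunks = []\nfor chunk in response:\n    if chunk.choices and chunk.choices[0].delta.content:\n        answer_chunks.append(chunk.choices[0].delta.content)\n        print(chunk.choices[0].delta.content, end=\"\", flush=True)\n\nanswer = \"\".join(answer_chunks)\n```\n\n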
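As a minimal sketch of building a compatible dataset with the `datasets` library (the prompts and repo name below are hypothetical), you could do:\n\n```python\nfrom datasets import Dataset\n\n# A toy dataset of single-turn prompts under the \"prompt\" column.\nds = Dataset.from_dict(\n    {\"prompt\": [\"what's 1+1?\", \"write a haiku about routers\"]}\n)\n\n# Push to the Hugging Face Hub to use as <hf_dataset_path> above,\n# or ds.save_to_disk(\"my_prompts\") for a local copy.\nds.push_to_hub(\"<your-username>/my-prompts\")\n```\n\n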
## AutoEval Suite\n\nOur in-depth evaluation code can be run using `p2l.auto_evals`.\n\n### Params\n\n- **a. Model List Params**\n    1. Either provide `--model_repo`, which has a `model_list.json` file.\n    2. Or provide a local `--model_list_path` file.\n\n- **b. Val Data**\n    1. **Data is in JSONL format**:\n        - Provide a local `--eval_path`.\n        - If no path is provided, the program will look for an `eval_output.jsonl` file in the `--model_repo` on HF.\n    2. **Data is in JSON format (checkpoint files)**:\n        - Provide a local `--checkpoint_path`.\n        - Or provide remote `--hf_checkpoint_repo` and `--hf_checkpoint_file`.\n\n- **c. Output Directory**\n    1. Provide a local `--output_dir` or a remote `--hf_output_dir`.\n    2. Provide `--output_file_name`.\n\n- **d. Train Data (Optional)**\n    - Provide `--hf_train_dataset` or a local `--train_path`.\n\n- **e. Arena Data (Optional)**\n    - Provide a local `--arena_path` (CSV with model rankings).\n\n- **f. Provide Model Info**\n    1. `--loss_type` (e.g., `bt`, `bt-tie`, `rk`).\n    2. `--model_type` (e.g., `p2l`, `marginal`, `arena`, `marginal-gt`).\n    3. `--categories`.\n\n- **g. Provide Types of Metrics**\n    1. `--simple_metrics`, `--category_metrics`, `--rand_subset_metrics`, `--aggr_scale_subset_metrics`.\n    2. Use `--metrics_to_inc` to select which of the above metrics to include.\n\n- **h. Random Subset Params**\n    1. `--rand_subset_sizes`: Specify subset sizes.\n    2. `--rand_num_samples`: Specify the number of samples per random subset size.\n\n- **i. Aggregation Subset Params**\n    1. `--aggr_scale_subset_sizes`: Specify subset sizes.\n    2. `--aggr_scale_num_samples`: Specify the number of samples per subset size.\n    3. `--aggr_scale_gt`: Specify whether to use `marginal-gt` or `arena` as ground truth for categories.\n\n---\n\n## Citation\n\n```\n@misc{frick2025prompttoleaderboard,\n      title={Prompt-to-Leaderboard}, \n      author={Evan Frick and Connor Chen and Joseph Tennyson and Tianle Li and Wei-Lin Chiang and Anastasios N. Angelopoulos and Ion Stoica},\n      year={2025},\n      eprint={2502.14855},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2502.14855}, \n}\n```\n"
  },
  {
    "path": "deepspeed/zero1.json",
    "content": "{\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n\n    \"fp16\": {\n        \"enabled\": \"auto\"\n    },\n\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 1,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": true,\n    \"zero_optimization\": {\n        \"stage\": 1,\n        \"reduce_bucket_size\": 5e8\n    },\n    \"optimizer\": {\n        \"type\": \"AdamW\",\n        \"params\": {\n          \"lr\": \"auto\",\n          \"betas\": [\n            0.9,\n            0.999\n          ],\n          \"eps\": \"auto\"\n        }\n    }\n}"
  },
  {
    "path": "fast_lambda_setup.sh",
    "content": "sudo apt-get update -y\nsudo apt-get install tmux -y\nsudo apt-get install python3-dev -y\n\nsudo apt-get install tmux libaio-dev libopenmpi-dev python3-mpi4py -y\n\ncurl -LsSf https://astral.sh/uv/install.sh | sh\n\nsource $HOME/.local/bin/env\n\nuv venv .env --python 3.10\n\nsource .env/bin/activate\n\nuv pip install wheel packaging\n\nuv pip install -r train_requirements.txt\nuv pip install flash-attn==2.5.9.post1 --no-build-isolation\n"
  },
  {
    "path": "fast_runpod_setup.sh",
    "content": "apt-get update -y\napt-get install tmux -y\napt-get install python3-dev -y\n\napt-get install tmux libaio-dev libopenmpi-dev python3-mpi4py -y\n\ncurl -LsSf https://astral.sh/uv/install.sh | sh\n\nsource $HOME/.local/bin/env\n\nuv venv .env --python 3.10\n\nsource .env/bin/activate\n\nuv pip install wheel packaging\n\nuv pip install -r train_requirements.txt\nuv pip install flash-attn==2.5.9.post1 --no-build-isolation\n"
  },
  {
    "path": "p2l/auto_eval_utils.py",
    "content": "from typing import Callable, Dict\n\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport pandas as pd\nimport numpy as np\nfrom scipy.optimize import minimize\n\n\nfrom scipy.stats import kendalltau, spearmanr\nfrom model import (\n    registered_losses,\n    HeadOutputs,\n    registered_aggr_models,\n    registered_pairwise_losses,\n)\n\nregistered_simple_metrics: Dict[str, Dict[str, Callable]] = {}\nregistered_aggr_metrics: Dict[str, Dict[str, Callable]] = {}\nregistered_helpers: Dict[str, Callable] = {}\n\n\ndef register_simple_metric(loss_type: str, metric: str):\n    def decorator(func: Callable):\n        if loss_type not in registered_simple_metrics:\n            registered_simple_metrics[loss_type] = {}\n        registered_simple_metrics[loss_type][metric] = func\n        return func\n\n    return decorator\n\n\ndef register_aggr_metric(loss_type: str, metric: str):\n    def decorator(func: Callable):\n        if loss_type not in registered_aggr_metrics:\n            registered_aggr_metrics[loss_type] = {}\n        registered_aggr_metrics[loss_type][metric] = func\n        return func\n\n    return decorator\n\n\ndef register_helper(loss_or_model_type: str, helper_func):\n    def decorator(func: Callable):\n        if loss_or_model_type not in registered_helpers:\n            registered_helpers[loss_or_model_type] = {}\n        registered_helpers[loss_or_model_type][helper_func] = func\n        return func\n\n    return decorator\n\n\n@register_helper(\"p2l\", \"output_labels\")\ndef output_labels_p2l(val_data: pd.DataFrame, **kwargs):\n    betas = torch.tensor(np.stack(val_data[\"betas\"]), dtype=torch.float)\n    labels = torch.tensor(np.stack(val_data[\"labels\"]))\n    etas = None\n\n    if \"eta\" in val_data.columns:\n        etas = torch.tensor(np.stack(val_data[\"eta\"]), dtype=torch.float)\n\n    return HeadOutputs(coefs=betas, eta=etas), labels\n\n\ndef translate_coefs(coef, old_list, new_list):\n    old_list = old_list.tolist()\n    old_to_new = [old_list.index(model) for model in new_list]\n    betas_array = np.array(coef)\n\n    betas_array = betas_array[old_to_new]\n\n    return torch.tensor(betas_array)\n\n\n@register_helper(\"marginal\", \"output_labels\")\ndef output_labels_marginal(\n    val_data: pd.DataFrame,\n    train_data: pd.DataFrame,\n    model_list: np.array,\n    train_model_list: np.array,\n    loss_type: str,\n    **kwargs,\n):\n    train_labels = torch.tensor(np.stack(train_data[\"labels\"]))\n    coefs, eta = train_marginal(train_model_list, train_labels, loss_type)\n    coefs, eta = coefs[0], eta[0] if eta is not None else None\n\n    coefs = translate_coefs(coefs, train_model_list, model_list)\n\n    val_labels = torch.tensor(np.stack(val_data[\"labels\"]))\n\n    coefs = coefs.expand(len(val_labels), -1)\n    eta = eta.expand(len(val_labels), -1) if eta is not None else None\n\n    return HeadOutputs(coefs=coefs, eta=eta), val_labels\n\n\n@register_helper(\"marginal-gt\", \"output_labels\")\ndef output_labels_marginal_gt(\n    val_data: pd.DataFrame, model_list: np.array, loss_type: str, **kwargs\n):\n    val_labels = torch.tensor(np.stack(val_data[\"labels\"]))\n    coefs, eta = train_marginal(model_list, val_labels, loss_type)\n\n    coefs = coefs.expand(len(val_labels), -1)\n    eta = eta.expand(len(val_labels), -1) if eta is not None else None\n\n    return HeadOutputs(coefs=coefs, eta=eta), val_labels\n\n\n@register_helper(\"arena\", \"output_labels\")\ndef 
output_labels_arena(\n    arena_rankings: torch.tensor, val_data: pd.DataFrame, loss_type: str, **kwargs\n):\n    labels = torch.tensor(np.stack(val_data[\"labels\"]))\n\n    # arena rankings is already filtered so it will be 1d tensor\n    betas = arena_rankings.expand(len(labels), -1)\n    etas = torch.ones(len(labels))\n    etas = etas.unsqueeze(-1)\n\n    # TODO: Cleanup\n    if loss_type == \"bt\" or loss_type == \"bt-tie\":\n        etas = None\n\n    return HeadOutputs(coefs=betas, eta=etas), labels\n\n\n@register_helper(\"bag\", \"preprocess_data\")\ndef preprocess_data_bag(data: pd.DataFrame, **kwargs):\n    condition = data[\"winner\"] == \"tie (bothbad)\"\n    data.loc[condition, \"labels\"] = data.loc[condition, \"labels\"].apply(\n        lambda arr: arr[:2] + [2]\n    )\n    return data\n\n\n@register_helper(\"bt\", \"preprocess_data\")\n@register_helper(\"bt-tie\", \"preprocess_data\")\n@register_helper(\"rk\", \"preprocess_data\")\n@register_helper(\"rk-reparam\", \"preprocess_data\")\ndef preprocess_data(data: pd.DataFrame, **kwargs):\n    return data\n\n\n@register_simple_metric(\"bt\", \"Loss\")\n@register_simple_metric(\"bt\", \"BCELoss\")\n@register_simple_metric(\"bt-tie\", \"Loss\")\n@register_simple_metric(\"rk\", \"Loss\")\n@register_simple_metric(\"rk-reparam\", \"Loss\")\n@register_simple_metric(\"bag\", \"Loss\")\ndef loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):\n    loss_func = registered_losses.get(loss_type)\n    return loss_func(head_output=head_output, labels=labels).item()\n\n\n@register_simple_metric(\"rk\", \"Tie_Loss\")\n@register_simple_metric(\"bag\", \"Tie_Loss\")\ndef tie_loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):\n    loss_func = registered_losses.get(\"tie-\" + loss_type)\n    return loss_func(head_output=head_output, labels=labels).item()\n\n\n@register_simple_metric(\"bag\", \"Tie_bb_Loss\")\ndef tie_bb_loss(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    loss_func = registered_losses.get(\"tie-bb-\" + loss_type)\n    return loss_func(head_output=head_output, labels=labels).item()\n\n\n@register_aggr_metric(\"bt\", \"Aggr_Tie_Loss\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_Tie_Loss\")\n@register_aggr_metric(\"rk\", \"Aggr_Tie_Loss\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_Tie_Loss\")\n@register_aggr_metric(\"bag\", \"Aggr_Tie_Loss\")\ndef Aggr_Tie_Loss(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    **kwargs,\n):\n\n    return aggr_metric(\"Tie_Loss\", loss_type, labels, gt_output, model_output)\n\n\n@register_simple_metric(\"bt-tie\", \"BCELoss\")\n@register_simple_metric(\"rk\", \"BCELoss\")\n@register_simple_metric(\"rk-reparam\", \"BCELoss\")\n@register_simple_metric(\"bag\", \"BCELoss\")\ndef BCE_loss(head_output: HeadOutputs, labels: torch.Tensor, **kwargs):\n    non_tie_index = torch.where(labels[:, -1] == 0)[0]\n\n    new_coefs = head_output.coefs[non_tie_index, :]\n    new_eta = head_output.eta[non_tie_index] if head_output.eta is not None else None\n\n    no_tie_output = HeadOutputs(coefs=new_coefs, eta=new_eta)\n    no_tie_labels = labels[non_tie_index, :]\n    return loss(no_tie_output, no_tie_labels, loss_type=\"bt\")\n\n\ndef aggr_metric(metric_name, loss_type, labels, gt_output, model_output):\n    func = registered_simple_metrics[loss_type][metric_name]\n\n    gt = func(\n        labels=labels, head_output=expand_output(gt_output, labels), 
loss_type=loss_type\n    )\n    model = func(\n        labels=labels,\n        head_output=expand_output(model_output, labels),\n        loss_type=loss_type,\n    )\n\n    return {\"ground-truth\": round(gt, 4), \"model-aggr\": round(model, 4)}\n\n\n@register_aggr_metric(\"bt\", \"Aggr_Loss\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_Loss\")\n@register_aggr_metric(\"rk\", \"Aggr_Loss\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_Loss\")\n@register_aggr_metric(\"bag\", \"Aggr_Loss\")\ndef Aggr_Loss(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    **kwargs,\n):\n\n    return aggr_metric(\"Loss\", loss_type, labels, gt_output, model_output)\n\n\n@register_aggr_metric(\"bt\", \"Aggr_BCELoss\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_BCELoss\")\n@register_aggr_metric(\"rk\", \"Aggr_BCELoss\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_BCELoss\")\n@register_aggr_metric(\"bag\", \"Aggr_BCELoss\")\ndef Aggr_BCE_Loss(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    **kwargs,\n):\n\n    return aggr_metric(\"BCELoss\", loss_type, labels, gt_output, model_output)\n\n\ndef expand_output(output, labels):\n    coefs, eta = output.coefs, output.eta\n    new_coefs = coefs.expand(len(labels), -1)\n\n    if eta is not None:\n        eta = eta.expand(len(labels), -1)\n    return HeadOutputs(coefs=new_coefs, eta=eta)\n\n\n@register_simple_metric(\"bt\", \"MSELoss\")\ndef BT_mse(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    **kwargs,\n):\n    coefs = head_output.coefs\n    paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()\n\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n    predicted_probs = torch.sigmoid(paired_delta_logit)\n    true_labels = torch.ones_like(predicted_probs)\n\n    mse = F.mse_loss(predicted_probs, true_labels)\n    return mse.mean().item()\n\n\n@register_simple_metric(\"bt-tie\", \"MSELoss\")\ndef BT_tie_mse(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    **kwargs,\n):\n    coefs = head_output.coefs\n    model_idx = labels[:, :2]\n\n    paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    p_w = torch.sigmoid(paired_delta_logit)\n    tie_ind = labels[:, -1]\n\n    # let the label be 0.5 if there is a tie\n    pred_probs = torch.where(tie_ind == 1, 0.5, p_w)\n\n    true_labels = torch.ones_like(pred_probs)\n    mse = F.mse_loss(pred_probs, true_labels)\n    return mse.mean().item()\n\n\n@register_simple_metric(\"rk\", \"MSELoss\")\n@register_simple_metric(\"rk-reparam\", \"MSELoss\")\ndef RK_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, _, p_t = probs_func(head_output=head_output, labels=labels)\n\n    tie_ind = labels[:, -1]\n\n    # True label will always be win (since first index) unless a tie occurs\n    pred_probs = torch.where(tie_ind == 1, p_t, p_w)\n\n    true_labels = torch.ones_like(pred_probs)\n    mse = F.mse_loss(pred_probs, true_labels)\n    return mse.mean().item()\n\n\n@register_simple_metric(\"bag\", \"MSELoss\")\ndef bag_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, _, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)\n\n    tie_ind = labels[:, -1].unsqueeze(-1)\n\n    P = 
torch.stack([p_w, p_t, p_t_bb], dim=-1)\n\n    pred_probs = P.gather(dim=-1, index=tie_ind).contiguous().squeeze(-1)\n\n    true_labels = torch.ones_like(pred_probs)\n    mse = F.mse_loss(pred_probs, true_labels)\n    return mse.mean().item()\n\n\n@register_helper(\"rk-reparam\", \"probs\")\ndef rk_reparam_probs(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n    eta = head_output.eta\n\n    theta = (torch.exp(eta) + 1.000001).squeeze(-1)\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()[:, 0]\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()[:, 0]\n\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n    p_win = pi_win / (pi_win + theta * pi_lose + 1.0)\n\n    p_lose = pi_lose / (pi_lose + theta * pi_win + 1.0)\n\n    p_tie = 1.0 - p_win - p_lose\n    return p_win, p_lose, p_tie\n\n\n@register_helper(\"bag\", \"probs\")\ndef bag_probs(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n    eta = head_output.eta\n\n    theta = (torch.exp(eta) + 1.000001).squeeze(-1)\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()[:, 0]\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()[:, 0]\n\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n    pi_gamma = 1.0\n\n    p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)\n    p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n    return p_win, p_lose, p_tie, p_tie_bb\n\n\n@register_helper(\"rk\", \"probs\")\ndef rk_probs(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n\n    eta = rk_eta(head_output)\n\n    model_idx = labels[:, :2]\n    paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    p_w = torch.sigmoid(paired_delta_logit - eta)\n    p_l = torch.sigmoid(-1 * paired_delta_logit - eta)\n    p_t = 1 - p_w - p_l\n\n    return p_w, p_l, p_t\n\n\n@register_simple_metric(\"bt\", \"Accuracy\")\ndef BT_accuracy(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    **kwargs,\n):\n    coefs = head_output.coefs\n    paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    # winner would have positive difference\n    correct = (paired_delta_logit > 0).float()\n    return correct.mean().item()\n\n\n@register_simple_metric(\"bt-tie\", \"Accuracy\")\ndef BT_tie_accuracy(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    **kwargs,\n):\n    coefs = head_output.coefs\n    paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()\n\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    # winner would have positive difference\n    correct = (paired_delta_logit > 0).float()\n    tie_ind = labels[:, -1]\n    # we give ties half the accuracy\n    correct[tie_ind == 1] = 0.5\n    return correct.mean().item()\n\n\n@register_simple_metric(\"rk\", \"Accuracy\")\n@register_simple_metric(\"rk-reparam\", \"Accuracy\")\ndef RK_accuracy(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, 
**kwargs\n):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, p_l, p_t = probs_func(head_output=head_output, labels=labels)\n\n    pred_labels = torch.where(\n        p_w >= p_l, torch.where(p_w >= p_t, 1, 0.5), torch.where(p_l >= p_t, 0, 0.5)\n    )\n\n    tie_ind = labels[:, -1]\n    # tie if tie index, else winner (first index) predicted to win\n    true_labels = torch.where(tie_ind == 1, 0.5, 1)\n\n    correct = (pred_labels == true_labels).float()\n    return correct.mean().item()\n\n\n@register_simple_metric(\"rk\", \"Tie_Accuracy\")\ndef RK_tie_accuracy(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, p_l, p_t = probs_func(head_output=head_output, labels=labels)\n\n    p_nt = p_w + p_l\n\n    pred_tie = torch.where(p_t >= p_nt, 1, 0)\n\n    tie_ind = labels[:, -1]\n    correct = (pred_tie == tie_ind).float()\n    return correct.mean().item()\n\n\n@register_simple_metric(\"bag\", \"Tie_Accuracy\")\ndef bag_tie_accuracy(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)\n\n    p_nt = p_w + p_l\n    p_tie = p_t + p_t_bb\n\n    pred_tie = torch.where(p_nt >= p_tie, 0, 1)\n\n    tie_ind = torch.where(labels[:, -1] == 0, 0, 1)\n    correct = (pred_tie == tie_ind).float()\n    return correct.mean().item()\n\n\n@register_simple_metric(\"bag\", \"Tie_bb_Accuracy\")\ndef bag_tie_bb_accuracy(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)\n\n    p_nt_bb = p_w + p_l + p_t\n\n    pred_tie = torch.where(p_t_bb >= p_nt_bb, 1, 0)\n\n    tie_ind = torch.where(labels[:, -1] == 2, 1, 0)\n    correct = (pred_tie == tie_ind).float()\n    return correct.mean().item()\n\n\n@register_aggr_metric(\"bt\", \"Aggr_Tie_Accuracy\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_Tie_Accuracy\")\n@register_aggr_metric(\"rk\", \"Aggr_Tie_Accuracy\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_Tie_Accuracy\")\n@register_aggr_metric(\"bag\", \"Aggr_Tie_Accuracy\")\ndef Aggr_Tie_accuracy(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    **kwargs,\n):\n\n    return aggr_metric(\"Tie_Accuracy\", loss_type, labels, gt_output, model_output)\n\n\n@register_aggr_metric(\"bt\", \"Aggr_Tie_bb_Accuracy\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_Tie_bb_Accuracy\")\n@register_aggr_metric(\"rk\", \"Aggr_Tie_bb_Accuracy\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_Tie_bb_Accuracy\")\n@register_aggr_metric(\"bag\", \"Aggr_Tie_bb_Accuracy\")\ndef Aggr_Tie_bb_accuracy(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    
**kwargs,\n):\n\n    return aggr_metric(\"Tie_bb_Accuracy\", loss_type, labels, gt_output, model_output)\n\n\n@register_aggr_metric(\"bt\", \"Aggr_Tie_bb_Loss\")\n@register_aggr_metric(\"bt-tie\", \"Aggr_Tie_bb_Loss\")\n@register_aggr_metric(\"rk\", \"Aggr_Tie_bb_Loss\")\n@register_aggr_metric(\"rk-reparam\", \"Aggr_Tie_bb_Loss\")\n@register_aggr_metric(\"bag\", \"Aggr_Tie_bb_Loss\")\ndef Aggr_Tie_bb_loss(\n    gt_output: HeadOutputs,\n    model_output: HeadOutputs,\n    loss_type: str,\n    labels: torch.tensor,\n    **kwargs,\n):\n\n    return aggr_metric(\"Tie_bb_Loss\", loss_type, labels, gt_output, model_output)\n\n\n@register_simple_metric(\"rk-reparam\", \"Tie_Accuracy\")\n@register_simple_metric(\"bt\", \"Tie_Accuracy\")\n@register_simple_metric(\"bt-tie\", \"Tie_Accuracy\")\n@register_simple_metric(\"bt\", \"Tie_bb_Loss\")\n@register_simple_metric(\"rk-reparam\", \"Tie_bb_Loss\")\n@register_simple_metric(\"bt-tie\", \"Tie_bb_Loss\")\n@register_simple_metric(\"rk\", \"Tie_bb_Loss\")\n@register_simple_metric(\"bt\", \"Tie_Loss\")\n@register_simple_metric(\"bt-tie\", \"Tie_Loss\")\n@register_simple_metric(\"rk-reparam\", \"Tie_Loss\")\n@register_simple_metric(\"rk\", \"Tie_bb_Accuracy\")\n@register_simple_metric(\"rk-reparam\", \"Tie_bb_Accuracy\")\n@register_simple_metric(\"bt\", \"Tie_bb_Accuracy\")\n@register_simple_metric(\"bt-tie\", \"Tie_bb_Accuracy\")\ndef not_implemented(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    return -1  # not implemented\n\n\n@register_simple_metric(\"bag\", \"Accuracy\")\ndef bag_accuracy(\n    head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs\n):\n    probs_func = registered_helpers[loss_type][\"probs\"]\n    p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)\n\n    P = torch.stack([p_w, p_t, p_t_bb, p_l], dim=-1)\n\n    pred_labels = P.argmax(dim=-1)\n\n    tie_ind = labels[:, -1]\n    # let win be 0, tie be 1, tie_bb be 2. 
loss never predicted since winner_idx first\n    true_labels = tie_ind\n\n    correct = (pred_labels == true_labels).float()\n    return correct.mean().item()\n\n\n@register_simple_metric(\"bt\", \"Mean-BT\")\n@register_simple_metric(\"bt-tie\", \"Mean-BT\")\n@register_simple_metric(\"rk\", \"Mean-BT\")\n@register_simple_metric(\"rk-reparam\", \"Mean-BT\")\n@register_simple_metric(\"bag\", \"Mean-BT\")\ndef beta_mean(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    flat_betas = betas.flatten()\n    return torch.mean(flat_betas).item()\n\n\n@register_simple_metric(\"bt\", \"Std-BT\")\n@register_simple_metric(\"bt-tie\", \"Std-BT\")\n@register_simple_metric(\"rk\", \"Std-BT\")\n@register_simple_metric(\"rk-reparam\", \"Std-BT\")\n@register_simple_metric(\"bag\", \"Std-BT\")\ndef beta_std(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    flat_betas = betas.flatten()\n    return torch.std(flat_betas).item()\n\n\n@register_simple_metric(\"bt\", \"Spread-BT\")\n@register_simple_metric(\"bt-tie\", \"Spread-BT\")\n@register_simple_metric(\"rk\", \"Spread-BT\")\n@register_simple_metric(\"rk-reparam\", \"Spread-BT\")\n@register_simple_metric(\"bag\", \"Spread-BT\")\ndef beta_spread(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    flat_betas = betas.flatten()\n    return (torch.max(flat_betas) - torch.min(flat_betas)).item()\n\n\n@register_simple_metric(\"bt\", \"Mean-Spread-BT\")\n@register_simple_metric(\"bt-tie\", \"Mean-Spread-BT\")\n@register_simple_metric(\"rk\", \"Mean-Spread-BT\")\n@register_simple_metric(\"rk-reparam\", \"Mean-Spread-BT\")\n@register_simple_metric(\"bag\", \"Mean-Spread-BT\")\ndef beta_mean_spread(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    max_min_per_prompt = (\n        torch.max(betas, dim=-1).values - torch.min(betas, dim=-1).values\n    )\n    return torch.mean(max_min_per_prompt).item()\n\n\n@register_simple_metric(\"bt\", \"Mean-IQR-BT\")\n@register_simple_metric(\"bt-tie\", \"Mean-IQR-BT\")\n@register_simple_metric(\"rk\", \"Mean-IQR-BT\")\n@register_simple_metric(\"rk-reparam\", \"Mean-IQR-BT\")\n@register_simple_metric(\"bag\", \"Mean-IQR-BT\")\ndef beta_mean_iqr(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    iqr_per_prompt = torch.quantile(betas, 0.75, dim=-1) - torch.quantile(\n        betas, 0.25, dim=-1\n    )\n    return torch.mean(iqr_per_prompt).item()\n\n\n@register_simple_metric(\"bt\", \"Mean-Std-BT\")\n@register_simple_metric(\"bt-tie\", \"Mean-Std-BT\")\n@register_simple_metric(\"rk\", \"Mean-Std-BT\")\n@register_simple_metric(\"rk-reparam\", \"Mean-Std-BT\")\n@register_simple_metric(\"bag\", \"Mean-Std-BT\")\ndef beta_mean_std(\n    head_output: HeadOutputs,\n    **kwargs,\n):\n    betas = head_output.coefs\n    std_per_prompt = torch.std(betas, dim=-1)\n    return torch.mean(std_per_prompt).item()\n\n\n@register_helper(\"marginal-gt\", \"aggregrate\")\ndef aggr_marginal_gt(\n    labels: torch.Tensor, model_list: torch.Tensor, loss_type: str, **kwargs\n):\n    coefs, eta = train_marginal(model_list, labels, loss_type)\n    return HeadOutputs(coefs=coefs[0], eta=eta[0] if eta is not None else None)\n\n\n@register_helper(\"p2l\", \"aggregrate\")\ndef aggr_p2l(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    model_list: torch.Tensor,\n    loss_type: str,\n    **kwargs,\n):\n    coefs, eta = train_aggr_prob(\n        model_list, head_output, labels, 
loss_type, is_batch=False\n    )\n    return HeadOutputs(coefs=coefs[0], eta=eta[0] if eta is not None else None)\n\n\n@register_helper(\"p2l\", \"aggregrate-batch\")\ndef aggr_p2l_batch(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    model_list: torch.Tensor,\n    loss_type: str,\n    **kwargs,\n):\n    coefs_batch, eta_batch = train_aggr_prob(\n        model_list, head_output, labels, loss_type, is_batch=True\n    )\n    return [\n        HeadOutputs(\n            coefs=coefs_batch[i], eta=eta_batch[i] if eta_batch is not None else None\n        )\n        for i in range(coefs_batch.shape[0])\n    ]\n\n\n@register_helper(\"marginal-gt\", \"aggregrate-batch\")\ndef aggr_marginal_gt_batch(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    model_list: torch.Tensor,\n    loss_type: str,\n    **kwargs,\n):\n    # TODO: Make faster if necessary\n    return [\n        aggr_marginal_gt(labels[i], model_list, loss_type) for i in range(len(labels))\n    ]\n\n\n@register_helper(\"marginal\", \"aggregrate\")\ndef aggr_non_p2l(head_output: HeadOutputs, loss_type: str, **kwargs):\n    etas = head_output.eta\n    etas = etas[0, :] if etas is not None else None\n    return HeadOutputs(coefs=head_output.coefs[0, :], eta=etas)\n\n\n@register_helper(\"arena\", \"aggregrate\")\ndef aggr_arena(\n    head_output: HeadOutputs = None, arena_rankings: torch.tensor = None, **kwargs\n):\n    eta = torch.tensor([0])\n\n    if arena_rankings is not None:\n        return HeadOutputs(coefs=arena_rankings, eta=eta)\n    # arena just has the same betas repeated if not provided\n    return HeadOutputs(coefs=head_output.coefs[0, :], eta=eta)\n\n\ndef train_marginal(model_list, labels, loss_type, lr=1.0, tol=1e-9, max_epochs=50):\n    model_cls = registered_aggr_models[loss_type]\n    model = model_cls(len(model_list))\n\n    optimizer = optim.LBFGS(\n        model.parameters(),\n        lr=lr,\n        max_iter=max_epochs,\n        tolerance_grad=tol,\n        tolerance_change=tol,\n    )\n\n    loss_func = registered_losses[loss_type]\n    labels = (\n        labels.squeeze() if labels.dim() > 2 else labels\n    )  # marginal doesn't use batching since one at a time\n\n    def closure():\n        optimizer.zero_grad()\n        coefs, eta = model()\n\n        coefs_expanded = coefs[0].expand(len(labels), -1)\n        eta_expanded = eta[0].expand(len(labels), -1) if eta is not None else None\n\n        head_output = HeadOutputs(coefs=coefs_expanded, eta=eta_expanded)\n        loss = loss_func(head_output=head_output, labels=labels)\n        loss.backward()\n        return loss\n\n    optimizer.step(closure)\n\n    true_coefs, true_eta = model()\n    return true_coefs.detach(), true_eta.detach() if true_eta is not None else None\n\n\ndef train_aggr_prob(\n    model_list,\n    head_outputs,\n    labels,\n    loss_type,\n    is_batch,\n    lr=1.0,\n    tol=1e-9,\n    max_epochs=50,\n):\n    true_probs_func = registered_helpers[loss_type][\"pairwise_probs\"]\n    true_probs = true_probs_func(real_output=head_outputs)\n    # add a batch size of 1 since aggregation is done in batches (only necessary if data isn't in batch format)\n    if not is_batch:\n        true_probs = true_probs.unsqueeze(0)\n\n    batch_size = true_probs.shape[0]\n    model_cls = registered_aggr_models[loss_type]\n    model = model_cls(len(model_list), batch_size)\n\n    optimizer = optim.LBFGS(\n        model.parameters(),\n        lr=lr,\n        max_iter=max_epochs,\n        tolerance_grad=tol,\n        tolerance_change=tol,\n    
)\n    loss_func = registered_pairwise_losses[loss_type]\n\n    count = 0\n\n    def closure():\n        optimizer.zero_grad()\n        coefs, eta = model()\n        aggr_output = HeadOutputs(coefs=coefs, eta=eta)\n        loss = loss_func(\n            real_output=head_outputs,\n            aggregated_output=aggr_output,\n            true_probs=true_probs,\n        )\n        loss.backward()\n\n        nonlocal count\n        count += 1\n        if count == 49:\n            warnings.warn(\"Batch training did not converge\")\n\n        return loss\n\n    optimizer.step(closure)\n\n    true_coefs, true_eta = model()\n    return true_coefs.detach(), true_eta.detach() if true_eta is not None else None\n\n\ndef rk_eta(output):\n    if output.eta is None:\n        return None\n    BETA = 0.1\n    return torch.clamp(\n        torch.nn.functional.softplus(output.eta - 22.5, BETA).squeeze(-1), min=0.02\n    )\n\n\n@register_helper(\"rk\", \"pairwise_probs\")\ndef pairwise_RK_probs(real_output: HeadOutputs):\n\n    real_betas = real_output.coefs\n    real_eta = rk_eta(real_output)\n    real_eta = real_eta.unsqueeze(-1)\n\n    num_models = real_betas.shape[-1]\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    # ellipsis indexing allows for both batched/unbatched\n    beta_i_real = real_betas[..., pair_indices[:, 0]]\n    beta_j_real = real_betas[..., pair_indices[:, 1]]\n\n    true_probs_win = torch.sigmoid(beta_i_real - beta_j_real - real_eta)\n    true_probs_loss = torch.sigmoid(beta_j_real - beta_i_real - real_eta)\n    true_probs_tie = 1.0 - true_probs_win - true_probs_loss\n\n    true_probs = torch.stack((true_probs_win, true_probs_loss, true_probs_tie), dim=-1)\n    return true_probs\n\n\n@register_helper(\"rk-reparam\", \"pairwise_probs\")\ndef pairwise_RK_reparam_probs(real_output: HeadOutputs, **kwargs):\n    real_betas = real_output.coefs\n    real_theta = torch.exp(real_output.eta) + 1.000001\n\n    num_models = real_betas.shape[-1]\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_real = real_betas[..., pair_indices[:, 0]]\n    beta_j_real = real_betas[..., pair_indices[:, 1]]\n\n    pi_win = torch.exp(beta_i_real)\n    pi_lose = torch.exp(beta_j_real)\n\n    p_win = pi_win / (pi_win + real_theta * pi_lose + 1.0)\n    p_lose = pi_lose / (pi_lose + real_theta * pi_win + 1.0)\n    p_tie = 1.0 - p_win - p_lose\n\n    true_probs = torch.stack((p_win, p_lose, p_tie), dim=-1)\n    return true_probs\n\n\n@register_helper(\"bag\", \"pairwise_probs\")\ndef pairwise_bag_probs(real_output: HeadOutputs, **kwargs):\n    real_betas = real_output.coefs\n    real_theta = torch.exp(real_output.eta) + 1.000001\n\n    num_models = real_betas.shape[-1]\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_real = real_betas[..., pair_indices[:, 0]]\n    beta_j_real = real_betas[..., pair_indices[:, 1]]\n\n    pi_win = torch.exp(beta_i_real)\n    pi_lose = torch.exp(beta_j_real)\n    pi_gamma = 1.0\n\n    p_win = pi_win / (pi_win + real_theta * pi_lose + pi_gamma)\n\n    p_lose = pi_lose / (pi_lose + real_theta * pi_win + pi_gamma)\n\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n\n    true_probs = torch.stack((p_win, 
p_lose, p_tie, p_tie_bb), dim=-1)\n    return true_probs\n\n\n@register_helper(\"bt\", \"pairwise_probs\")\n@register_helper(\"bt-tie\", \"pairwise_probs\")\ndef pairwise_BT_probs(real_output: HeadOutputs):\n    real_betas = real_output.coefs\n\n    num_models = real_betas.shape[-1]\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_real = real_betas[..., pair_indices[:, 0]]\n    beta_j_real = real_betas[..., pair_indices[:, 1]]\n\n    true_probs = torch.sigmoid(beta_i_real - beta_j_real)\n    return true_probs\n\n\n# removes nan from tensor, indices will be shifted\ndef remove_beta_nan(beta1, beta2):\n    beta_mask = ~torch.isnan(beta1) & ~torch.isnan(beta2)\n    return beta1[beta_mask], beta2[beta_mask]\n\n\n@register_aggr_metric(\"bt\", \"Leaderboard\")\n@register_aggr_metric(\"bt-tie\", \"Leaderboard\")\n@register_aggr_metric(\"rk\", \"Leaderboard\")\n@register_aggr_metric(\"rk-reparam\", \"Leaderboard\")\n@register_aggr_metric(\"bag\", \"Leaderboard\")\ndef leaderboard(\n    gt_output: HeadOutputs, model_output: HeadOutputs, model_list: np.array, **kwargs\n):\n    gt_lb = get_leaderboard(gt_output, model_list)\n    model_lb = get_leaderboard(model_output, model_list)\n\n    return {\"ground-truth\": list(gt_lb), \"model-aggr\": list(model_lb)}\n\n\ndef get_leaderboard(output, model_list):\n    coefs = output.coefs\n\n    sorted_indices = torch.argsort(coefs, descending=True)\n    sorted_model_names = [model_list[i] for i in sorted_indices]\n    sorted_betas = coefs[sorted_indices]\n\n    leaderboard = []\n    for i in range(len(sorted_model_names)):\n        beta = (\n            round(sorted_betas[i].item(), 4)\n            if not torch.isnan(sorted_betas[i])\n            else \"nan\"\n        )\n        cur_model = str(sorted_model_names[i]) + \": \" + str(beta)\n        leaderboard.append(cur_model)\n\n    return np.array(leaderboard)\n\n\n@register_aggr_metric(\"bt\", \"L1-Dist-Prob\")\n@register_aggr_metric(\"bt-tie\", \"L1-Dist-Prob\")\ndef l1_dist_prob_bt(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    beta1 = gt_output.coefs\n    beta2 = model_output.coefs\n\n    # if arena is one, there may be nan if model not present in that file\n    beta1, beta2 = remove_beta_nan(beta1, beta2)\n\n    diff_matrix1 = beta1.unsqueeze(1) - beta1.unsqueeze(0)\n    diff_matrix2 = beta2.unsqueeze(1) - beta2.unsqueeze(0)\n\n    prob_vec1 = torch.sigmoid(diff_matrix1).flatten()\n    prob_vec2 = torch.sigmoid(diff_matrix2).flatten()\n\n    return torch.abs(prob_vec2 - prob_vec1).mean().item()\n\n\n@register_aggr_metric(\"rk-reparam\", \"L1-Dist-Prob\")\n@register_aggr_metric(\"rk\", \"L1-Dist-Prob\")\ndef l1_dist_prob_rk(\n    gt_output: HeadOutputs, model_output: HeadOutputs, loss_type: str, **kwargs\n):\n    eta1 = gt_output.eta\n    eta2 = model_output.eta\n    # need to both have eta\n    if eta1 is None or eta2 is None:\n        return l1_dist_prob_bt(gt_output, model_output)\n\n    pair_probs_func = registered_helpers[loss_type][\"pairwise_probs\"]\n\n    p_win1, p_lose1, p_tie1 = torch.unbind(pair_probs_func(gt_output), -1)\n    p_win2, p_lose2, p_tie2 = torch.unbind(pair_probs_func(model_output), -1)\n\n    win_diff = torch.abs(p_win1 - p_win2).mean().item()\n    lose_diff = torch.abs(p_lose1 - p_lose2).mean().item()\n    tie_diff = torch.abs(p_tie1 - p_tie2).mean().item()\n    return (win_diff + lose_diff + tie_diff) / 3\n\n\n@register_aggr_metric(\"bag\", 
\"L1-Dist-Prob\")\ndef l1_dist_prob_bag(\n    gt_output: HeadOutputs, model_output: HeadOutputs, loss_type: str, **kwargs\n):\n    eta1 = gt_output.eta\n    eta2 = model_output.eta\n    # need to both have eta\n    if eta1 is None or eta2 is None:\n        return l1_dist_prob_bt(gt_output, model_output)\n\n    pair_probs_func = registered_helpers[loss_type][\"pairwise_probs\"]\n\n    p_win1, p_lose1, p_tie1, p_tie_bb1 = torch.unbind(pair_probs_func(gt_output), -1)\n    p_win2, p_lose2, p_tie2, p_tie_bb2 = torch.unbind(pair_probs_func(model_output), -1)\n\n    win_diff = torch.abs(p_win1 - p_win2).mean().item()\n    lose_diff = torch.abs(p_lose1 - p_lose2).mean().item()\n    tie_diff = torch.abs(p_tie1 - p_tie2).mean().item()\n    tie_bb_diff = torch.abs(p_tie_bb2 - p_tie_bb1).mean().item()\n    return (win_diff + lose_diff + tie_diff + tie_bb_diff) / 4\n\n\n@register_aggr_metric(\"bt\", \"IQR-BT\")\n@register_aggr_metric(\"bt-tie\", \"IQR-BT\")\n@register_aggr_metric(\"rk\", \"IQR-BT\")\n@register_aggr_metric(\"rk-reparam\", \"IQR-BT\")\n@register_aggr_metric(\"bag\", \"IQR-BT\")\ndef beta_iqr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    (\n        gt_coefs,\n        model_coefs,\n    ) = (\n        gt_output.coefs,\n        model_output.coefs,\n    )\n    gt_iqr = (torch.quantile(gt_coefs, 0.75) - torch.quantile(gt_coefs, 0.25)).item()\n    model_iqr = (\n        torch.quantile(model_coefs, 0.75) - torch.quantile(model_coefs, 0.25)\n    ).item()\n\n    return {\"ground-truth\": round(gt_iqr, 4), \"model-aggr\": round(model_iqr, 4)}\n\n\n@register_aggr_metric(\"bt\", \"Std-BT\")\n@register_aggr_metric(\"bt-tie\", \"Std-BT\")\n@register_aggr_metric(\"rk\", \"Std-BT\")\n@register_aggr_metric(\"rk-reparam\", \"Std-BT\")\n@register_aggr_metric(\"bag\", \"Std-BT\")\ndef beta_std_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n\n    gt_betas, model_betas = gt_output.coefs, model_output.coefs\n    gt_std, model_std = (\n        torch.std(gt_betas.flatten()).item(),\n        torch.std(model_betas.flatten()).item(),\n    )\n    return {\"ground-truth\": round(gt_std, 4), \"model-aggr\": round(model_std, 4)}\n\n\n@register_aggr_metric(\"bt\", \"Spread-BT\")\n@register_aggr_metric(\"bt-tie\", \"Spread-BT\")\n@register_aggr_metric(\"rk\", \"Spread-BT\")\n@register_aggr_metric(\"rk-reparam\", \"Spread-BT\")\n@register_aggr_metric(\"bag\", \"Spread-BT\")\ndef beta_spread_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    gt_betas, model_betas = gt_output.coefs.flatten(), model_output.coefs.flatten()\n\n    gt_spread, model_spread = torch.max(gt_betas) - torch.min(gt_betas), torch.max(\n        model_betas\n    ) - torch.min(model_betas)\n    return {\n        \"ground-truth\": round(gt_spread.item(), 4),\n        \"model-aggr\": round(model_spread.item(), 4),\n    }\n\n\n@register_aggr_metric(\"bt\", \"Kendall-lbs\")\n@register_aggr_metric(\"bt-tie\", \"Kendall-lbs\")\n@register_aggr_metric(\"rk\", \"Kendall-lbs\")\n@register_aggr_metric(\"rk-reparam\", \"Kendall-lbs\")\n@register_aggr_metric(\"bag\", \"Kendall-lbs\")\ndef kendall_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)\n    gt_lb = gt_betas.numpy()\n    model_lb = model_betas.numpy()\n\n    return kendalltau(gt_lb, model_lb)[0]\n\n\n@register_aggr_metric(\"bt\", \"Spearman-lbs\")\n@register_aggr_metric(\"bt-tie\", \"Spearman-lbs\")\n@register_aggr_metric(\"rk\", 
\"Spearman-lbs\")\n@register_aggr_metric(\"rk-reparam\", \"Spearman-lbs\")\n@register_aggr_metric(\"bag\", \"Spearman-lbs\")\ndef spearman_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)\n    gt_lb = gt_betas.numpy()\n    model_lb = model_betas.numpy()\n\n    return spearmanr(gt_lb, model_lb)[0]\n\n\ndef top_k_frac(gt_betas: torch.tensor, model_betas: torch.tensor, k: int):\n    gt_top_indices = set(torch.topk(gt_betas, k).indices.numpy())\n    model_top_indices = set(torch.topk(model_betas, k).indices.numpy())\n    common_indices = gt_top_indices & model_top_indices\n\n    return len(common_indices) / k\n\n\ndef top_k_displace(gt_betas: torch.tensor, model_betas: torch.tensor, k: int):\n    gt_top_indices = torch.topk(gt_betas, k).indices\n    model_ranks = torch.argsort(torch.argsort(model_betas, descending=True))\n    displacements = torch.abs(model_ranks[gt_top_indices] - torch.arange(k))\n\n    return displacements.float().mean().item()\n\n\n@register_aggr_metric(\"bt\", \"Top-k-fraction\")\n@register_aggr_metric(\"bt-tie\", \"Top-k-fraction\")\n@register_aggr_metric(\"rk\", \"Top-k-fraction\")\n@register_aggr_metric(\"rk-reparam\", \"Top-k-fraction\")\n@register_aggr_metric(\"bag\", \"Top-k-fraction\")\ndef top_k_frac_dict(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)\n\n    res = {}\n    for k in [1, 3, 5, 10]:\n        res[k] = round(top_k_frac(gt_betas, model_betas, k), 4)\n\n    return res\n\n\n@register_aggr_metric(\"bt\", \"Top-k-displace\")\n@register_aggr_metric(\"bt-tie\", \"Top-k-displace\")\n@register_aggr_metric(\"rk\", \"Top-k-displace\")\n@register_aggr_metric(\"rk-reparam\", \"Top-k-displace\")\n@register_aggr_metric(\"bag\", \"Top-k-displace\")\ndef top_k_dist_dict(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):\n    gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)\n\n    res = {}\n    for k in [1, 3, 5, 10]:\n        res[k] = round(top_k_displace(gt_betas, model_betas, k), 4)\n\n    return res\n"
  },
  {
    "path": "p2l/auto_evals.py",
    "content": "import argparse\nimport json\nimport os\nimport io\nimport warnings\nimport math\nfrom tqdm import tqdm\nimport time\nimport copy\n\nimport torch\nimport pandas as pd\nimport numpy as np\nfrom datasets import load_dataset, load_from_disk\nfrom huggingface_hub import hf_hub_download, upload_file, list_repo_files\n\nfrom model import HeadOutputs\nfrom auto_eval_utils import (\n    registered_simple_metrics,\n    registered_aggr_metrics,\n    registered_helpers,\n)\n\n\ndef parse_model_list(hf_model, local_path):\n    if not hf_model and not local_path:\n        raise ValueError(\"Either model repo or local model list must be provided.\")\n\n    model_list_path = local_path\n    # if no local path, try getting from model_repo\n    if not model_list_path:\n        model_list_path = hf_hub_download(\n            repo_id=hf_model, filename=\"model_list.json\", repo_type=\"model\"\n        )\n\n    model_list = pd.read_json(model_list_path, lines=False).iloc[:, 0].tolist()\n    return np.array(model_list)\n\n\ndef change_beta_model_list(df, old_list, new_list):\n    old_list = old_list.tolist()\n    old_to_new = [old_list.index(model) for model in new_list]\n    betas_array = np.array(df[\"betas\"].to_list())\n\n    betas_array = betas_array[:, old_to_new]\n    return betas_array.tolist()\n\n\ndef parse_eval_output_data(\n    model_repo,\n    local_eval_path,\n    local_checkpoint_path,\n    hf_checkpoint_repo,\n    hf_checkpoint_file,\n    loss_type,\n    model_list,\n    remove_last_hidden_json,\n):\n    ret_df, ret_model_list = None, None\n    if local_checkpoint_path or hf_checkpoint_repo:\n        path = local_checkpoint_path\n        if not path:\n            if not hf_checkpoint_file:\n                raise ValueError(\n                    \"Must provide checkpoint file along with checkpoint repo\"\n                )\n            path = hf_hub_download(\n                repo_id=hf_checkpoint_repo,\n                filename=hf_checkpoint_file,\n                repo_type=\"dataset\",\n            )\n\n        df = pd.read_json(path)\n\n        # caching json w/o last hidden layer\n        if remove_last_hidden_json and local_checkpoint_path:\n            if \"last_hidden_state\" in df.columns:\n                df = df.drop(columns=[\"last_hidden_state\"])\n                df.to_json(local_checkpoint_path)\n\n        df = df.rename(columns={\"coefs\": \"betas\"})\n\n        # data is stored with nested lists for both etas and betas only in checkpoint data\n        # df['eta'] = np.array(df['eta'].to_list()).flatten()\n        df[\"eta\"] = df[\"eta\"].apply(lambda x: x[0] if isinstance(x, list) else x)\n        df[\"betas\"] = df[\"betas\"].apply(lambda x: x[0] if isinstance(x, list) else x)\n\n        val_model_list = get_model_list_from_df(df)\n        # only betas need to be adjusted since labels are correct\n        df[\"betas\"] = change_beta_model_list(df, model_list, val_model_list)\n\n        ret_df, ret_model_list = df, val_model_list\n\n    elif local_eval_path:\n        ret_df, ret_model_list = pd.read_json(local_eval_path, lines=True), model_list\n\n    elif model_repo:\n        files = list_repo_files(repo_id=model_repo, repo_type=\"model\")\n        if \"eval_output.jsonl\" not in files:\n            raise FileNotFoundError(\n                f\"'eval_output.jsonl' not found in the hf repository'{model_repo}'.\"\n            )\n        path = hf_hub_download(\n            repo_id=model_repo, filename=\"eval_output.jsonl\", repo_type=\"model\"\n        )\n        
        ret_df, ret_model_list = pd.read_json(path, lines=True), model_list\n    else:\n        raise ValueError(\"need to provide path for eval output data\")\n\n    preprocess_func = registered_helpers[loss_type][\"preprocess_data\"]\n    ret_df = preprocess_func(data=ret_df)\n\n    return ret_df, ret_model_list\n\n\ndef add_labels_to_data(data, loss_type, model_list):\n    if loss_type == \"bt\":\n        data = data[~data[\"winner\"].isin([\"tie\", \"tie (bothbad)\"])]\n\n    def create_labels(row):\n        winner = row[\"winner\"]\n        model_a = row[\"model_a\"]\n        model_b = row[\"model_b\"]\n\n        model_a_idx = np.where(model_list == model_a)[0][0]\n        model_b_idx = np.where(model_list == model_b)[0][0]\n\n        tie_bb_label = 2 if loss_type == \"bag\" else 1\n        if winner == \"model_a\":\n            return np.array([model_a_idx, model_b_idx, 0])\n        elif winner == \"model_b\":\n            return np.array([model_b_idx, model_a_idx, 0])\n        elif winner == \"tie\":\n            return np.array([model_a_idx, model_b_idx, 1])\n        else:\n            return np.array([model_a_idx, model_b_idx, tie_bb_label])\n\n    data[\"labels\"] = data.apply(create_labels, axis=1)\n    return data\n\n\n# only use if completely necessary\ndef get_model_list_from_df(df):\n    return np.array(sorted(pd.concat([df[\"model_a\"], df[\"model_b\"]]).unique()))\n\n\ndef parse_train_data(hf_data, local_path, loss_type, train_model_list):\n    if not hf_data and not local_path:\n        warnings.warn(\n            \"No train data provided, the marginal model type will not work if specified\"\n        )\n        return\n\n    if local_path:\n        if local_path.endswith(\".jsonl\"):\n            data = pd.read_json(local_path, lines=True)\n\n        else:\n            data = load_from_disk(local_path)[\"train\"].to_pandas()\n    else:\n        data = load_dataset(hf_data, split=\"train\").to_pandas()\n\n    return add_labels_to_data(data, loss_type, train_model_list)\n\n\ndef parse_arena_data(path, model_list, initial_rating=1000, BASE=10, SCALE=400):\n    if not path:\n        warnings.warn(\"Ground truth arena data not passed in, some metrics will not work\")\n        return\n\n    df = pd.read_csv(path)\n    # drop style-controlled rows to avoid duplicates, since not every model has a style-controlled ranking\n    df = df[df[\"style_control\"] == False]\n    # convert Elo ratings to Bradley-Terry betas, matching eval_p2l.ipynb\n    df[\"beta\"] = (df[\"rating\"] - initial_rating) / (SCALE * math.log(BASE))\n\n    pivot = df.pivot(index=\"model_name\", columns=\"category\", values=\"beta\").reindex(\n        model_list\n    )\n\n    if pivot.isnull().any().any():\n        missing_models = pivot[pivot.isnull().any(axis=1)].index.tolist()\n        warnings.warn(\"Models not included in the arena leaderboard: \" + str(missing_models))\n\n    category_to_betas = {\n        category: torch.tensor(pivot[category].values, dtype=torch.float)\n        for category in pivot.columns\n    }\n    return category_to_betas\n\n\n# NOTE: only supports certain categories; new ones must be added manually\ndef filter_battle_data(battles, category):\n    if battles is None:\n        return None\n    # expect category key by itself or key=value\n    key_val_pair = category.split(\"=\")\n    key = key_val_pair[0]\n    val = key_val_pair[1] if len(key_val_pair) == 2 else True\n    # compare against the true-strings explicitly: bool(\"False\") would be True\n    val = val in [\"True\", \"true\"] if val in [\"True\", \"true\", \"False\", \"false\"] else val\n\n    try:\n        # no filtering\n        if key == \"all\":\n            return battles\n        # no nesting\n
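        # e.g. category=\"language=Chinese\" filters on battles[\"language\"] == \"Chinese\"\n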
 if key == \"language\" or key == \"is_code\":\n            return battles[battles[key] == val]\n        # nested ones need specific cases\n        if key == \"math\":\n            return battles[\n                battles[\"category_tag\"].apply(lambda x: x[\"math_v0.1\"][\"math\"])\n            ]\n        if key == \"complexity\":\n            return battles[\n                battles[\"category_tag\"].apply(\n                    lambda x: x[\"criteria_v0.1\"][\"complexity\"]\n                )\n            ]\n        if key == \"creative_writing\":\n            return battles[\n                battles[\"category_tag\"].apply(\n                    lambda x: x[\"creative_writing_v0.1\"][\"creative_writing\"]\n                )\n            ]\n        if key == \"hard\":\n            return battles[\n                battles[\"category_tag\"].apply(\n                    lambda x: sum(x[\"criteria_v0.1\"].values()) >= 6\n                )\n            ]\n\n        # Category not found\n        return None\n    except:\n        return None\n\n\n# NOTE: Only accepts certain categories, needs to be manually added\ndef get_arena_rankings(data, category):\n    if data is None:\n        return None\n\n    key_val_pair = category.split(\"=\")\n    key = key_val_pair[0]\n    val = key_val_pair[1] if len(key_val_pair) == 2 else True\n    val = bool(val) if val in [\"True\", \"true\", \"False\", \"false\"] else val\n\n    try:\n        # no filtering\n        if key == \"all\":\n            return data[\"full\"]\n        # no nesting\n        if key == \"language\":\n            return data[val.lower()]\n        if key == \"is_code\":\n            return data[\"coding\"]\n        if key == \"math\":\n            return data[\"math\"]\n        if key == \"hard\":\n            return data[\"hard_6\"]\n        if key == \"creative_writing\":\n            return data[\"creative_writing\"]\n\n        return None\n    except:\n        return None\n\n\ndef get_subset_prompts(output, labels, size):\n    num_prompts = output.coefs.shape[0]\n    sampled_indices = torch.randperm(num_prompts)[:size]\n    sampled_coefs = output.coefs[sampled_indices, :]\n\n    sampled_eta = None\n    if output.eta is not None:\n        sampled_eta = output.eta[sampled_indices]\n\n    sampled_labels = labels[sampled_indices, :]\n    sampled_output = HeadOutputs(coefs=sampled_coefs, eta=sampled_eta)\n    return sampled_output, sampled_labels\n\n\ndef get_subset_prompts_batch(output, labels, size, batch_size):\n    num_prompts, num_models = output.coefs.shape\n    sampled_indices = torch.randint(low=0, high=num_prompts, size=(batch_size, size))\n    sampled_coefs = output.coefs[sampled_indices]\n\n    sampled_eta = None\n    if output.eta is not None:\n        sampled_eta = output.eta[sampled_indices]\n    sampled_labels = labels[sampled_indices]\n\n    sampled_output = HeadOutputs(coefs=sampled_coefs, eta=sampled_eta)\n\n    return sampled_output, sampled_labels\n\n\ndef get_ith_output(output, i):\n    betas = output.coefs[i]\n    eta = output.eta[i] if output.eta is not None else None\n    return HeadOutputs(coefs=betas, eta=eta)\n\n\ndef save_output(results, local_dir, hf_dir, file_name):\n    if not local_dir and not hf_dir:\n        raise ValueError(\"Specify a directory for outputs.\")\n\n    results[\"params\"][\"output_file_name\"] = file_name\n\n    file_name += \".json\"\n    if local_dir:\n        path = os.path.join(local_dir, file_name)\n        with open(path, \"w\") as file:\n            json.dump(results, file, 
indent=4, separators=(\",\", \": \"))\n    if hf_dir:\n        output = json.dumps(results, indent=4, separators=(\",\", \": \"))\n        tmp_file = io.BytesIO(output.encode(\"utf-8\"))\n\n        upload_file(\n            path_or_fileobj=tmp_file,\n            path_in_repo=file_name,\n            repo_id=hf_dir,\n            repo_type=\"model\",\n        )\n\n\ndef simple_metrics(metrics, output, labels, loss_type):\n    results = {}\n\n    for metric in tqdm(metrics, desc=\"Simple Metrics\", unit=\"metrics\"):\n        metric_dict = registered_simple_metrics[loss_type]\n        metric_func = metric_dict[metric]\n        metric_val = metric_func(head_output=output, labels=labels, loss_type=loss_type)\n\n        results[metric] = (\n            round(metric_val, 4) if isinstance(metric_val, float) else metric_val\n        )\n\n    return results\n\n\ndef category_metrics(\n    metrics,\n    output,\n    labels,\n    loss_type,\n    model_type,\n    model_list,\n    ground_truth,\n    arena_rankings,\n):\n    results = {}\n\n    aggr_func_model = registered_helpers[model_type][\"aggregrate\"]\n    # our default ground truth is marginal-gt but we can switch to arena or add configurability if desired\n    aggr_func_gt = registered_helpers[ground_truth][\"aggregrate\"]\n\n    model_output = aggr_func_model(\n        head_output=output, labels=labels, model_list=model_list, loss_type=loss_type\n    )\n    gt_output = aggr_func_gt(\n        labels=labels,\n        model_list=model_list,\n        loss_type=loss_type,\n        arena_rankings=arena_rankings,\n    )\n\n    for metric in tqdm(metrics, desc=\"Category Metrics\", unit=\"metric\"):\n        metric_dict = registered_aggr_metrics[loss_type]\n        metric_func = metric_dict[metric]\n        metric_val = metric_func(\n            gt_output=gt_output,\n            model_output=model_output,\n            model_list=model_list,\n            loss_type=loss_type,\n            labels=labels,\n        )\n        results[metric] = (\n            round(metric_val, 4) if isinstance(metric_val, float) else metric_val\n        )\n\n    return results\n\n\ndef random_subset_metrics(\n    metrics,\n    output,\n    labels,\n    subset_sizes,\n    trials_per_subset,\n    loss_type,\n    model_type,\n    model_list,\n):\n    results = {}\n\n    aggr_func_model = registered_helpers[model_type][\"aggregrate\"]\n    # our default ground truth is marginal-gt but we can switch to arena or add configurability if desired\n    aggr_func_gt = registered_helpers[\"marginal-gt\"][\"aggregrate\"]\n\n    for idx, size in enumerate(subset_sizes):\n        size = int(size)\n        subset_results = {metric: 0 for metric in metrics}\n\n        for _ in tqdm(\n            range(trials_per_subset[idx]),\n            desc=f\"Random Subset size {size}\",\n            unit=\"trial\",\n        ):\n            sample_output, sample_labels = get_subset_prompts(output, labels, size)\n\n            model_output = aggr_func_model(\n                head_output=sample_output,\n                labels=sample_labels,\n                model_list=model_list,\n                loss_type=loss_type,\n            )\n            gt_output = aggr_func_gt(\n                labels=sample_labels, model_list=model_list, loss_type=loss_type\n            )\n\n            for metric in metrics:\n                metric_dict = registered_aggr_metrics[loss_type]\n                metric_func = metric_dict[metric]\n                metric_val = metric_func(\n                    gt_output=gt_output,\n       
             model_output=model_output,\n                    model_list=model_list,\n                    loss_type=loss_type,\n                )\n\n                subset_results[metric] += metric_val\n\n        for metric in metrics:\n            subset_results[metric] = round(\n                subset_results[metric] / trials_per_subset[idx], 4\n            )\n\n        results[size] = subset_results\n\n    return results\n\n\ndef aggr_scale_metrics(\n    metrics,\n    output,\n    labels,\n    subset_sizes,\n    trials_per_subset,\n    loss_type,\n    model_type,\n    model_list,\n    arena_rankings,\n    gt,\n):\n    results = {}\n    aggr_func_model = registered_helpers[model_type][\"aggregrate-batch\"]\n    # the ground truth is selected via the gt argument (marginal-gt by default); more configurability can be added if desired\n\n    aggr_func_gt = registered_helpers[gt][\"aggregrate\"]\n    gt_output = aggr_func_gt(\n        labels=labels,\n        model_list=model_list,\n        loss_type=loss_type,\n        arena_rankings=arena_rankings,\n    )\n\n    # TODO: arbitrary threshold to limit memory consumption for batching\n    # max_prompts_times_samples_squared = 2e4\n\n    for idx, size in enumerate(subset_sizes):\n        size = int(size)\n        num_samples = int(trials_per_subset[idx])\n\n        subset_results = {metric: 0 for metric in metrics}\n\n        # num_full_mini_batches = int(max(\n        #     1, (size * (num_samples ** 2)) // max_prompts_times_samples_squared\n        # ))\n\n        num_full_mini_batches = int(max(1, num_samples // 100))\n\n        mini_batch_size = num_samples // num_full_mini_batches\n        leftover = num_samples - (num_full_mini_batches * mini_batch_size)\n\n        with tqdm(total=num_samples, desc=f\"Aggr Subset Size {size}\") as pbar:\n\n            def run_mini_batch(batch_count):\n                sample_output, sample_labels = get_subset_prompts_batch(\n                    output, labels, size, batch_count\n                )\n                batch_output = aggr_func_model(\n                    head_output=sample_output,\n                    labels=sample_labels,\n                    model_list=model_list,\n                    loss_type=loss_type,\n                )\n\n                for cur_output in batch_output:\n                    for metric in metrics:\n                        metric_dict = registered_aggr_metrics[loss_type]\n                        metric_func = metric_dict[metric]\n                        metric_val = metric_func(\n                            gt_output=gt_output,\n                            model_output=cur_output,\n                            model_list=model_list,\n                            loss_type=loss_type,\n                        )\n                        subset_results[metric] += metric_val\n                    pbar.update(1)\n\n            for _ in range(num_full_mini_batches):\n                run_mini_batch(mini_batch_size)\n\n            if leftover > 0:\n                run_mini_batch(leftover)\n\n        for metric in metrics:\n            subset_results[metric] = round(\n                subset_results[metric] / float(trials_per_subset[idx]), 4\n            )\n\n        results[size] = subset_results\n\n    return results\n\n\ndef get_metrics(\n    val_data, train_data, arena_rankings, val_model_list, train_model_list, args\n):\n    results = {}\n
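    # metrics_to_inc selects which metric families to compute: simple, category, random_subsets, aggr_scale\n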
    to_inc = set(args.metrics_to_inc)\n    output_label_func = registered_helpers[args.model_type][\"output_labels\"]\n    output, labels = output_label_func(\n        val_data=val_data,\n        train_data=train_data,\n        arena_rankings=arena_rankings,\n        loss_type=args.loss_type,\n        model_list=val_model_list,\n        train_model_list=train_model_list,\n    )\n\n    if \"simple\" in to_inc:\n        simple_results = simple_metrics(\n            metrics=args.simple_metrics,\n            output=output,\n            labels=labels,\n            loss_type=args.loss_type,\n        )\n        results[\"simple_metrics\"] = simple_results\n\n    if \"category\" in to_inc:\n        category_results = category_metrics(\n            metrics=args.category_metrics,\n            loss_type=args.loss_type,\n            model_type=args.model_type,\n            model_list=val_model_list,\n            output=output,\n            labels=labels,\n            ground_truth=args.ground_truth,\n            arena_rankings=arena_rankings,\n        )\n        results[\"category_metrics\"] = category_results\n\n    if \"random_subsets\" in to_inc:\n        subset_results = random_subset_metrics(\n            metrics=args.rand_subset_metrics,\n            subset_sizes=args.rand_subset_sizes,\n            trials_per_subset=args.rand_num_samples,\n            loss_type=args.loss_type,\n            model_type=args.model_type,\n            model_list=val_model_list,\n            output=output,\n            labels=labels,\n        )\n        results[\"random_subsets\"] = subset_results\n\n    if \"aggr_scale\" in to_inc:\n        scale_results = aggr_scale_metrics(\n            metrics=args.aggr_scale_metrics,\n            subset_sizes=args.aggr_scale_subset_sizes,\n            trials_per_subset=args.aggr_scale_num_samples,\n            loss_type=args.loss_type,\n            model_type=args.model_type,\n            model_list=val_model_list,\n            output=output,\n            labels=labels,\n            arena_rankings=arena_rankings,\n            gt=args.ground_truth,\n        )\n        results[\"aggr_scale\"] = scale_results\n\n    return results\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # model repo contains the model list and, potentially, the eval data (eval_output.jsonl)\n    parser.add_argument(\"--model_repo\", type=str, default=None)\n    parser.add_argument(\"--model_list_path\", type=str, default=None)\n\n    # val data comes from the model repo, a local file, or a remote checkpoint file\n    parser.add_argument(\"--eval_path\", nargs=\"+\", type=str, default=None)\n    parser.add_argument(\"--checkpoint_path\", nargs=\"+\", type=str, default=None)\n    parser.add_argument(\"--hf_checkpoint_repo\", type=str, default=None)\n    parser.add_argument(\"--hf_checkpoint_file\", nargs=\"+\", type=str, default=None)\n\n    parser.add_argument(\"--output_dir\", type=str, default=None)\n    parser.add_argument(\"--hf_output_dir\", type=str, default=None)\n    parser.add_argument(\n        \"--output_file_name\", type=str, nargs=\"+\", default=[\"eval_metrics\"]\n    )\n\n    parser.add_argument(\"--hf_train_dataset\", type=str, default=None)\n    parser.add_argument(\"--train_path\", type=str, default=None)\n\n    parser.add_argument(\"--arena_path\", type=str, default=None)\n\n    parser.add_argument(\n        \"--loss_type\", type=str, default=\"bt\", help=\"bt, bt-tie, rk, rk-reparam, bag\"\n    )\n    parser.add_argument(\n        \"--model_type\", type=str, default=\"p2l\", help=\"p2l, marginal, arena\"\n    )\n\n    parser.add_argument(\n        \"--categories\",\n        nargs=\"*\",\n        default=[\n            \"all\",\n            \"creative_writing\",\n
        \"math\",\n            \"language=Chinese\",\n            \"is_code\",\n            \"hard\",\n        ],\n    )\n\n    parser.add_argument(\n        \"--simple_metrics\",\n        nargs=\"*\",\n        default=[\n            \"Loss\",\n            \"BCELoss\",\n            \"MSELoss\",\n            \"Accuracy\",\n            \"Tie_Loss\",\n            \"Tie_Accuracy\",\n            \"Tie_bb_Accuracy\",\n            \"Tie_bb_Loss\",\n            \"Mean-BT\",\n            \"Std-BT\",\n            \"Spread-BT\",\n            \"Mean-Spread-BT\",\n            \"Mean-IQR-BT\",\n            \"Mean-Std-BT\",\n        ],\n    )\n\n    parser.add_argument(\"--train_checkpoints\", nargs=\"+\", type=int, default=[])\n    parser.add_argument(\"--checkpoint_size\", type=int, default=0)\n\n    # gt is marginal on val\n    parser.add_argument(\n        \"--category_metrics\",\n        nargs=\"*\",\n        default=[\n            \"Leaderboard\",\n            \"Aggr_Loss\",\n            \"Aggr_BCELoss\",\n            \"Aggr_Tie_Loss\",\n            \"Aggr_Tie_Accuracy\",\n            \"Aggr_Tie_bb_Accuracy\",\n            \"Aggr_Tie_bb_Loss\",\n            \"L1-Dist-Prob\",\n            \"Spearman-lbs\",\n            \"Kendall-lbs\",\n            \"IQR-BT\",\n            \"Std-BT\",\n            \"Spread-BT\",\n            \"Top-k-fraction\",\n            \"Top-k-displace\",\n        ],\n    )\n\n    parser.add_argument(\n        \"--rand_subset_sizes\", nargs=\"*\", default=[250, 500, 1000, 2000]\n    )\n\n    parser.add_argument(\"--rand_num_samples\", nargs=\"*\", default=[50, 20, 5, 3])\n    parser.add_argument(\n        \"--rand_subset_metrics\",\n        nargs=\"*\",\n        default=[\"L1-Dist-Prob\", \"Spearman-lbs\", \"Kendall-lbs\"],\n    )\n    # gt is arena leaderboard\n    parser.add_argument(\n        \"--aggr_scale_subset_sizes\",\n        nargs=\"*\",\n        default=[1, 10, 25, 100, 250, 500, 1000, 2000],\n    )\n    parser.add_argument(\n        \"--aggr_scale_num_samples\",\n        nargs=\"*\",\n        default=[500, 500, 500, 200, 100, 40, 10, 6],\n    )\n\n    parser.add_argument(\n        \"--aggr_scale_metrics\",\n        nargs=\"*\",\n        default=[\"L1-Dist-Prob\", \"Spearman-lbs\", \"Kendall-lbs\"],\n    )\n    parser.add_argument(\"--ground_truth\", type=str, default=\"marginal-gt\")\n\n    parser.add_argument(\n        \"--metrics_to_inc\",\n        nargs=\"*\",\n        default=[\"simple\", \"category\", \"random_subsets\", \"aggr_scale\"],\n    )\n\n    parser.add_argument(\"--remove_last_hidden_json\", default=True)\n\n    args = parser.parse_args()\n    start_time = time.time()\n    for idx in range(len(args.output_file_name)):\n        results = {}\n        results[\"params\"] = copy.deepcopy(vars(args))\n\n        train_model_list = parse_model_list(args.model_repo, args.model_list_path)\n\n        eval_path = args.eval_path[idx] if args.eval_path else None\n        checkpoint_path = args.checkpoint_path[idx] if args.checkpoint_path else None\n        hf_checkpoint_file = (\n            args.hf_checkpoint_file[idx] if args.hf_checkpoint_file else None\n        )\n\n        # make sure right params are dumped\n        results[\"params\"][\"eval_path\"] = eval_path\n        results[\"params\"][\"checkpoint_path\"] = checkpoint_path\n        results[\"params\"][\"hf_checkpoint_file\"] = hf_checkpoint_file\n\n        val_data, val_model_list = parse_eval_output_data(\n            args.model_repo,\n            eval_path,\n            checkpoint_path,\n            
        val_data, val_model_list = parse_eval_output_data(\n            args.model_repo,\n            eval_path,\n            checkpoint_path,\n            args.hf_checkpoint_repo,\n            hf_checkpoint_file,\n            args.loss_type,\n            train_model_list,\n            args.remove_last_hidden_json,\n        )\n\n        train_data = parse_train_data(\n            args.hf_train_dataset, args.train_path, args.loss_type, train_model_list\n        )\n        arena_data = parse_arena_data(args.arena_path, val_model_list)\n\n        models = {}\n        for category in args.categories:\n\n            cat_val_data = filter_battle_data(val_data, category)\n            cat_train_data = filter_battle_data(train_data, category)\n\n            arena_rankings = get_arena_rankings(arena_data, category)\n\n            current_model = str(args.model_type) + \"-\" + category\n            models[current_model] = get_metrics(\n                cat_val_data,\n                cat_train_data,\n                arena_rankings,\n                val_model_list,\n                train_model_list,\n                args,\n            )\n\n            # only used for marginal train checkpointing\n            for checkpoint in args.train_checkpoints:\n                num_data = checkpoint * args.checkpoint_size\n                checkpoint_train_data = train_data.head(num_data)\n\n                cat_train_data = filter_battle_data(checkpoint_train_data, category)\n                models[current_model + f\"-checkpoint-{checkpoint}\"] = get_metrics(\n                    cat_val_data,\n                    cat_train_data,\n                    arena_rankings,\n                    val_model_list,\n                    train_model_list,\n                    args,\n                )\n\n        results[\"models\"] = models\n        save_output(\n            results, args.output_dir, args.hf_output_dir, args.output_file_name[idx]\n        )\n\n    end_time = time.time()\n    total_time = end_time - start_time\n\n    minutes = int(total_time // 60)\n    seconds = int(total_time % 60)\n\n    print(f\"\\nTotal time taken: {minutes} minutes and {seconds} seconds\")\n"
  },
  {
    "path": "p2l/dataset.py",
    "content": "from transformers import PreTrainedTokenizer\nfrom datasets import Dataset, DatasetDict, load_dataset, load_from_disk\nimport torch\nfrom typing import List\n\n\ndef get_model_list(dataset: Dataset):\n\n    model_a_values = dataset.unique(\"model_a\")\n    model_b_values = dataset.unique(\"model_b\")\n\n    model_list_with_repeats = []\n\n    for value in model_a_values:\n        model_list_with_repeats.append(value)\n\n    for value in model_b_values:\n        model_list_with_repeats.append(value)\n\n    model_set = set(model_list_with_repeats)\n\n    model_list = sorted(list(model_set))\n\n    return model_list\n\n\ndef get_dataset(path: str, split: str, from_disk=False):\n    if from_disk:\n        dataset = load_from_disk(path)\n\n        if isinstance(dataset, DatasetDict):\n        \n            dataset = dataset[split]\n\n        return dataset\n    else:\n        return load_dataset(path, split=split)\n\n\ndef _translate_label(\n    labels: List[int], train_model_list: List[str], val_model_list: List[str]\n) -> List[int]:\n    label_copy = labels[:]\n\n    label_copy[0] = train_model_list.index(val_model_list[labels[0]])\n    label_copy[1] = train_model_list.index(val_model_list[labels[1]])\n\n    return label_copy\n\n\ndef translate_val_data(\n    val_data: Dataset, train_model_list: List[str], val_model_list: List[str]\n) -> Dataset:\n\n    # Validate val models\n    for val_model in val_model_list:\n        assert val_model in train_model_list, val_model\n\n    # Translate val dataset\n    val_data = val_data.map(\n        lambda labels: {\n            \"labels\": _translate_label(labels, train_model_list, val_model_list)\n        },\n        input_columns=\"labels\",\n        num_proc=16,\n    )\n\n    return val_data\n\n\nclass DataCollator:\n    def __init__(self, tokenizer, max_length, weight=None, reweight_scale=None):\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n        self.max_length: int = max_length\n        self.weight: bool = weight\n        self.reweight_scale: float = reweight_scale\n        self.first = True\n\n    def __call__(self, data):\n\n        prompts = []\n\n        for seq in data:\n\n            if isinstance(seq[\"prompt\"], str):\n                prompts.append([{\"role\": \"user\", \"content\": seq[\"prompt\"]}])\n            else:\n                prompts.append([{\"role\": \"user\", \"content\": turn} for turn in seq[\"prompt\"]])\n        \n        labels = torch.tensor([seq[\"labels\"].tolist() for seq in data])\n\n        formatted_prompts = self.tokenizer.apply_chat_template(\n            prompts,\n            tokenize=False,\n            add_generation_prompt=False,\n            add_special_tokens=False,\n        )\n\n        # Scrub any instances of cls token from the data, otherwise model will error.\n        formatted_prompts = [\n            prompt.replace(self.tokenizer.cls_token, \"<cls>\")\n            for prompt in formatted_prompts\n        ]\n\n        formatted_prompts = [\n            seq + self.tokenizer.cls_token for seq in formatted_prompts\n        ]\n\n        if self.first:\n            print(formatted_prompts)\n            self.first = False\n\n        encoded = self.tokenizer(\n            formatted_prompts,\n            padding=True,\n            return_tensors=\"pt\",\n            add_special_tokens=False,\n            truncation=True,\n            max_length=self.max_length,\n        )\n\n        out = {\n            \"input_ids\": encoded[\"input_ids\"],\n            \"attention_mask\": 
encoded[\"attention_mask\"],\n            \"labels\": labels,\n        }\n\n        if self.weight:\n            if \"weight\" in data[0]:\n                out[\"weights\"] = torch.tensor([seq[\"weight\"].tolist() for seq in data])\n                if self.reweight_scale:\n                    out[\"weights\"] *= self.reweight_scale\n            else:\n                out[\"weights\"] = None\n\n        return out\n"
  },
  {
    "path": "p2l/endpoint.py",
    "content": "import argparse\nimport json\nfrom typing import Dict, Tuple, List, Optional\n\nimport torch\nimport uvicorn\nfrom fastapi import FastAPI, Header, HTTPException\nfrom huggingface_hub import hf_hub_download\nfrom pydantic import BaseModel\nfrom transformers import (\n    AutoTokenizer,\n    TextClassificationPipeline,\n    pipeline,\n    PreTrainedModel,\n)\n\nfrom p2l.model import get_p2l_model, P2LOutputs\nfrom contextlib import asynccontextmanager\nimport logging\n\nlogging.getLogger().setLevel(logging.DEBUG)\n\n\ndef parse_args():\n\n    parser = argparse.ArgumentParser(description=\"Run FastAPI with Uvicorn\")\n\n    parser.add_argument(\n        \"--model-path\",\n        \"-m\",\n        type=str,\n        default=\"p2el/Qwen2.5-7B-Instruct-rk-full-train\",\n        help=\"Path to the model repository\",\n    )\n    parser.add_argument(\n        \"--model-type\",\n        \"-mt\",\n        type=str,\n        default=\"qwen2\",\n        help=\"Type of the model\",\n    )\n    parser.add_argument(\n        \"--head-type\",\n        \"-ht\",\n        type=str,\n        default=\"rk\",\n        help=\"Type of model head\",\n    )\n    parser.add_argument(\n        \"--loss-type\",\n        \"-lt\",\n        type=str,\n        default=\"rk\",\n        help=\"Type of the loss function\",\n    )\n    parser.add_argument(\n        \"--api-key\",\n        \"-a\",\n        type=str,\n        default=\"-\",\n        help=\"API key for authorization\",\n    )\n    parser.add_argument(\n        \"--host\",\n        \"-H\",\n        type=str,\n        default=\"0.0.0.0\",\n        help=\"Host to run the server on\",\n    )\n    parser.add_argument(\n        \"--port\",\n        \"-p\",\n        type=int,\n        default=10250,\n        help=\"Port to run the server on\",\n    )\n\n    parser.add_argument(\n        \"--reload\",\n        action=argparse.BooleanOptionalAction,\n        default=True,\n        help=\"Whether to reload the endpoint on detected code change, needs workers to be 1.\",\n    )\n    parser.add_argument(\n        \"--workers\",\n        type=int,\n        default=1,\n        help=\"Number of endpoint workers (will hold a model per worker).\",\n    )\n    parser.add_argument(\n        \"--cuda\",\n        action=argparse.BooleanOptionalAction,\n        default=True,\n        help=\"Flag to enable using a GPU to host the model. 
Flag is true by default.\",\n    )\n\n    args = parser.parse_args()\n\n    return args\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n\n    args = parse_args()\n\n    model, tokenizer, model_list = load_model(\n        args.model_path,\n        args.model_type,\n        args.head_type,\n        args.loss_type,\n    )\n\n    pipe = pipeline(\n        task=\"text-classification\",\n        model=model,\n        tokenizer=tokenizer,\n        device=\"cuda\" if args.cuda else \"cpu\",\n        pipeline_class=P2LPipeline,\n    )\n\n    app.state.api_key = args.api_key\n    app.state.model_list = model_list\n    app.state.model = model\n    app.state.tokenizer = tokenizer\n    app.state.pipe = pipe\n\n    try:\n\n        yield\n\n    finally:\n\n        pass\n\n\n# Initialize FastAPI app\napp = FastAPI(lifespan=lifespan)\n\n\n# Define the input data structure\nclass InputData(BaseModel):\n    prompt: list[str]\n\n\nclass OutputData(BaseModel):\n    coefs: List[float]\n    eta: Optional[float] = None\n\n\nclass ModelList(BaseModel):\n    models: List[str]\n\n\nclass P2LPipeline(TextClassificationPipeline):\n    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:\n        return_tensors = self.framework\n\n        inputs = inputs[\"prompt\"]\n\n        messages = [{\"role\": \"user\", \"content\": p} for p in inputs]\n\n        formatted = self.tokenizer.apply_chat_template(\n            messages,\n            tokenize=False,\n            add_generation_prompt=False,\n            add_special_tokens=False,\n        )\n        formatted = formatted + self.tokenizer.cls_token\n\n        logging.debug(f\"Formatted input: {formatted}\")\n\n        return self.tokenizer(\n            formatted,\n            return_tensors=return_tensors,\n            max_length=8192,\n            padding=\"longest\",\n            truncation=True,\n        )\n\n    def postprocess(\n        self, model_outputs: P2LOutputs, function_to_apply=None, top_k=1, _legacy=True\n    ):\n        model_outputs = P2LOutputs(model_outputs)\n\n        eta = model_outputs.eta\n\n        return OutputData(\n            coefs=model_outputs.coefs.cpu().float().tolist()[0],\n            eta=eta.cpu().float().item() if eta else None,\n        )\n\n\n@app.post(\"/predict\")\nasync def predict(input_data: InputData, api_key: str = Header(...)):\n\n    logging.debug(f\"Received Request: {input_data}.\")\n\n    if api_key != app.state.api_key:\n\n        raise HTTPException(status_code=403, detail=\"Unauthorized\")\n\n    try:\n        pipe: P2LPipeline = app.state.pipe\n\n        logging.debug(f\"Input Prompt: {input_data.prompt}\")\n\n        output = pipe(inputs=input_data.model_dump())\n\n        logging.debug(f\"Output: {output}\")\n\n        return output\n\n    except Exception as e:\n\n        logging.debug(e)\n\n        raise HTTPException(status_code=500, detail=str(e))\n\n\n@app.get(\"/models\")\nasync def models(api_key: str = Header(...)):\n\n    logging.debug(f\"Received Model List Request.\")\n\n    if api_key != app.state.api_key:\n\n        raise HTTPException(status_code=403, detail=\"Unauthorized\")\n\n    try:\n\n        return ModelList(\n            models=app.state.model_list,\n        )\n\n    except Exception as e:\n\n        raise HTTPException(status_code=500, detail=str(e))\n\n\ndef load_model(\n    model_name, model_type, head_type, loss_type\n) -> Tuple[PreTrainedModel, AutoTokenizer, List[str]]:\n\n    # Download and load the model list\n    fname = hf_hub_download(\n        
repo_id=model_name, filename=\"model_list.json\", repo_type=\"model\"\n    )\n    with open(fname) as fin:\n        model_list = json.load(fin)\n\n    # Initialize tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    tokenizer.truncation_side = \"left\"\n    tokenizer.padding_side = \"right\"\n\n    # Get the model class and load the model\n    model_cls = get_p2l_model(model_type, loss_type, head_type)\n\n    model = model_cls.from_pretrained(\n        model_name,\n        CLS_id=tokenizer.cls_token_id,\n        num_models=len(model_list),\n        torch_dtype=torch.bfloat16,\n    )\n    return model, tokenizer, model_list\n\n\nif __name__ == \"__main__\":\n\n    args = parse_args()\n\n    uvicorn.run(\n        \"p2l.endpoint:app\",\n        port=args.port,\n        host=args.host,\n        reload=args.reload,\n        workers=args.workers,\n    )\n"
  },
  {
    "path": "p2l/eval.py",
    "content": "import argparse\nfrom p2l.model import get_p2l_model, P2LOutputs\nfrom transformers import pipeline, TextClassificationPipeline, AutoTokenizer\nfrom huggingface_hub import hf_hub_download\nfrom datasets import load_dataset\nimport torch\nfrom typing import Dict\nimport pandas as pd\nimport os\nimport json\nfrom tqdm.auto import tqdm\nfrom torch.utils.data import Dataset\nfrom glob import glob\n\n\nclass P2LPipeline(TextClassificationPipeline):\n    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:\n        return_tensors = self.framework\n\n        messages = [{\"role\": \"user\", \"content\": inputs}]\n\n        formatted = self.tokenizer.apply_chat_template(\n            messages,\n            tokenize=False,\n            add_generation_prompt=False,\n            add_special_tokens=False,\n        )\n\n        formatted = formatted + self.tokenizer.cls_token\n\n        return self.tokenizer(\n            formatted,\n            return_tensors=return_tensors,\n            max_length=8192,\n            padding=\"longest\",\n            truncation=True,\n        )\n\n    def postprocess(\n        self, model_outputs: P2LOutputs, function_to_apply=None, top_k=1, _legacy=True\n    ):\n\n        model_outputs = P2LOutputs(model_outputs)\n\n        eta = model_outputs.eta\n        gamma = model_outputs.gamma\n\n\n        return dict(\n            coefs=model_outputs.coefs.cpu().float().numpy(),\n            eta=eta.cpu().float().numpy() if eta else None,\n            gamma=gamma.cpu().float().numpy() if gamma else None,\n            last_hidden_state=model_outputs.last_hidden_state.cpu().float().numpy(),\n        )\n\n\nclass ListDataset(Dataset):\n    def __init__(self, original_list):\n        self.original_list = original_list\n\n    def __len__(self):\n        return len(self.original_list)\n\n    def __getitem__(self, i):\n        return self.original_list[i]\n\n\ndef main(args, local_file=None):\n\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    dataset = load_dataset(args.dataset, split=args.dataset_split)\n    \n    if local_file:\n        fname = os.path.join(local_file, \"model_list.json\")\n    else:\n        fname = hf_hub_download(\n            repo_id=args.model_path, filename=\"model_list.json\", repo_type=\"model\"\n        )\n\n    with open(fname) as fin:\n        model_list = json.load(fin)\n\n    model_cls = get_p2l_model(args.model_type, args.loss_type, args.head_type)\n\n    if local_file:\n        tokenizer = AutoTokenizer.from_pretrained(local_file, local_files_only=True)\n        model = model_cls.from_pretrained(\n            local_file,\n            CLS_id=tokenizer.cls_token_id,\n            num_models=len(model_list),\n            torch_dtype=torch.bfloat16,\n            local_files_only=True,\n        )\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_path)\n        model = model_cls.from_pretrained(\n            args.model_path,\n            CLS_id=tokenizer.cls_token_id,\n            num_models=len(model_list),\n            torch_dtype=torch.bfloat16,\n        )\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    pipe = pipeline(\n        task=\"text-classification\",\n        model=model,\n        tokenizer=tokenizer,\n        device=device,\n        pipeline_class=P2LPipeline,\n    )\n\n    prompts = ListDataset(dataset[\"prompt\"])\n\n    with torch.no_grad():\n        outputs = [\n            out\n            for out in tqdm(\n                pipe(prompts, 
    with torch.no_grad():\n        outputs = [\n            out\n            for out in tqdm(\n                pipe(prompts, batch_size=args.batch_size), total=len(prompts)\n            )\n        ]\n\n    df = dataset.to_pandas()\n\n    outputs_df = pd.DataFrame.from_records(outputs)\n\n    if args.drop_hidden:\n\n        outputs_df = outputs_df.drop(\"last_hidden_state\", axis=1)\n\n    df = pd.concat((df, outputs_df), axis=1)\n\n    if local_file:\n        fname = local_file.split(\"/\")[-1] + \".json\"\n    else:\n        fname = args.model_path.split(\"/\")[-1] + \".json\"\n    fpath = os.path.join(args.output_dir, fname)\n    df.to_json(fpath, orient=\"records\", indent=4, force_ascii=False)\n\n    if args.output_hf_path:\n        from datasets import Dataset\n\n        df = pd.read_json(fpath)\n        hf_dataset = Dataset.from_pandas(df)\n        hf_dataset.push_to_hub(args.output_hf_path, private=True)\n        print(\"Results pushed to hub!\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model-path\", \"-m\", type=str, default=None, help=\"Huggingface model path\"\n    )\n    parser.add_argument(\n        \"--training-output-dir\", \"-t\", type=str, default=None\n    )\n    parser.add_argument(\n        \"--dataset\", \"-d\", type=str, required=True, help=\"Huggingface dataset path\"\n    )\n    parser.add_argument(\"--output-hf-path\", \"-oh\", type=str, default=None)\n    parser.add_argument(\n        \"--dataset-split\",\n        \"-ds\",\n        type=str,\n        default=\"train\",\n        help=\"Huggingface dataset split\",\n    )\n    parser.add_argument(\n        \"--model-type\",\n        \"-mt\",\n        type=str,\n        default=\"qwen2\",\n        help=\"Model type (qwen2, llama, etc)\",\n    )\n    parser.add_argument(\n        \"--head-type\",\n        \"-ht\",\n        type=str,\n        default=\"bt\",\n        help=\"Head type (Bradley-Terry, Rao-Kupper, etc)\",\n    )\n    parser.add_argument(\n        \"--loss-type\",\n        \"-lt\",\n        type=str,\n        default=\"bt\",\n        help=\"Loss type (Bradley-Terry, Rao-Kupper, etc)\",\n    )\n    parser.add_argument(\"--batch-size\", \"-bs\", type=int, default=1, help=\"Batch size\")\n    parser.add_argument(\"--output-dir\", \"-od\", type=str, default=\"outputs\")\n    parser.add_argument(\"--drop-hidden\", action=argparse.BooleanOptionalAction, default=False)\n\n    args = parser.parse_args()\n\n    if args.training_output_dir:\n        for file in glob(os.path.join(args.training_output_dir, \"*\")):\n            main(args, file)\n    else:\n        main(args)\n"
  },
  {
    "path": "p2l/model.py",
    "content": "import torch\nfrom transformers import (\n    Qwen2Model,\n    Qwen2PreTrainedModel,\n    LlamaModel,\n    LlamaPreTrainedModel,\n    PreTrainedModel,\n    AutoTokenizer,\n)\nfrom transformers.utils import ModelOutput\nfrom dataclasses import dataclass\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Dict, Tuple, Callable, Optional\n\n\nregistered_transformers: Dict[str, Tuple[PreTrainedModel, PreTrainedModel]] = {\n    \"qwen2\": (Qwen2PreTrainedModel, Qwen2Model),\n    \"llama\": (LlamaPreTrainedModel, LlamaModel),\n}\n\nregistered_losses: Dict[str, Callable] = {}\nregistered_heads: Dict[str, nn.Module] = {}\nregistered_inits: Dict[str, Callable] = {}\n\nregistered_aggr_models: Dict[str, nn.Module] = {}\nregistered_pairwise_losses: Dict[str, Callable] = {}\n\n\ndef register_loss(name: str):\n    def decorator(func: Callable):\n        registered_losses[name] = func\n        return func\n\n    return decorator\n\n\ndef register_head(name: str):\n    def decorator(func: Callable):\n        registered_heads[name] = func\n        return func\n\n    return decorator\n\n\ndef register_init(name: str):\n    def decorator(func: Callable):\n        registered_inits[name] = func\n        return func\n\n    return decorator\n\n\ndef register_aggr_model(name: str):\n    def decorator(func: Callable):\n        registered_aggr_models[name] = func\n        return func\n\n    return decorator\n\n\ndef register_pairwise_loss(name: str):\n    def decorator(func: Callable):\n        registered_pairwise_losses[name] = func\n        return func\n\n    return decorator\n\n\ndef register_init(name: str):\n    def decorator(func: Callable):\n        registered_inits[name] = func\n        return func\n\n    return decorator\n\n\n@dataclass\nclass HeadOutputs(ModelOutput):\n    coefs: torch.FloatTensor = None\n    eta: Optional[torch.FloatTensor] = None\n    gamma: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass P2LOutputs(ModelOutput):\n    coefs: torch.FloatTensor = None\n    eta: Optional[torch.FloatTensor] = None\n    gamma: Optional[torch.FloatTensor] = None\n    loss: Optional[torch.FloatTensor] = None\n    last_hidden_state: torch.FloatTensor = None\n\n\n@register_loss(\"bt\")\ndef BT_loss(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    weights: torch.Tensor = None,\n    **kwargs,\n):\n    # labels columns are in the form (winner_idx, loser_idx)\n\n    coefs = head_output.coefs\n\n    paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()\n\n    paired_delta_logit = (\n        paired_coefs[:, 0] - paired_coefs[:, 1]\n    )  # subtract winner bt from loser bt\n\n    neg_log_sigma = -F.logsigmoid(paired_delta_logit)  # get neg log prob\n\n    if weights is not None:\n        neg_log_sigma = neg_log_sigma * weights\n\n    loss = neg_log_sigma.mean()\n\n    return loss\n\n\n@register_loss(\"bt-tie\")\ndef BT_tie_loss(\n    head_output: HeadOutputs,\n    labels: torch.Tensor,\n    weights: torch.Tensor = None,\n    **kwargs,\n):\n    # labels columns are in the form (winner_idx, loser_idx, tie_indicator)\n\n    coefs = head_output.coefs\n\n    model_idx = labels[:, :2]  # (batch_dim, 2)\n    tie_ind = labels[:, -1]\n\n    paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()\n\n    paired_delta_logit = (\n        paired_coefs[:, 0] - paired_coefs[:, 1]\n    )  # subtract winner bt from loser bt\n\n    # computes bradley-terry loss where tie is half win and half loss\n    neg_log_sigma = -1 * torch.where(\n        tie_ind == 
BETA = 0.1\n\n\n@register_loss(\"rk\")\ndef RK_Loss(\n    head_output: HeadOutputs, labels: torch.Tensor, weights: torch.Tensor = None, **kwargs\n):\n    # labels columns are in form (winner_idx, loser_idx, tie_indicator)\n    coefs = head_output.coefs\n    # eta = torch.exp(head_output.eta).squeeze(-1)  # eta > 0\n    eta = torch.clamp(\n        torch.nn.functional.softplus(head_output.eta - 22.5, BETA).squeeze(-1), min=0.02\n    )\n    # eta = torch.abs(head_output.eta).squeeze(-1)\n    model_idx = labels[:, :2]  # (batch_dim, 2)\n    paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()\n\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    # compute RK probabilities\n    p_w = torch.sigmoid(paired_delta_logit - eta)\n    p_l = torch.sigmoid(-1 * paired_delta_logit - eta)\n    p_t = 1 - p_w - p_l\n\n    # point-wise likelihood\n    A = torch.stack((p_w, p_t))  # (2, batch_dim)\n\n    tie_ind = labels[:, -1].unsqueeze(0)  # (1, batch_dim)\n    p = A.take_along_dim(dim=0, indices=tie_ind)\n\n    # mathematically 0 < p_t < 1 always, but bfloat16 rounding can push p to 0\n    p = torch.clamp(p, min=1e-3)\n\n    # eps = 1e-10\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n\n    return loss\n\n\n@register_loss(\"rk-reparam\")\ndef RK_Reparam_Loss(\n    head_output: HeadOutputs, labels: torch.Tensor, weights: torch.Tensor = None, **kwargs\n):\n\n    coefs = head_output.coefs\n    eta = head_output.eta\n\n    theta = torch.exp(eta) + 1.000001\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()\n\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n\n    p_win = pi_win / (pi_win + theta * pi_lose + 1.0)\n\n    p_lose = pi_lose / (pi_lose + theta * pi_win + 1.0)\n\n    p_tie = 1.0 - p_win - p_lose\n\n    assert p_win.shape == p_lose.shape == p_tie.shape\n\n    P = torch.hstack((p_win, p_tie))\n    tie_ind = labels[:, -1].unsqueeze(-1)\n\n    p = P.gather(dim=-1, index=tie_ind).contiguous()\n\n    p = torch.clamp(p, min=1e-6)\n\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n\n    return loss\n\n\n
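# \"ba\": adds a learned gamma coefficient for the tie-(both-bad) outcome\n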
@register_loss(\"ba\")\ndef BA_loss(\n    head_output: HeadOutputs, labels: torch.Tensor, weights: torch.Tensor = None, **kwargs\n):\n    # labels are (winner_idx, loser_idx, tie_indicator (0 for no tie, 1 for tie, 2 for tie both bad))\n\n    coefs = head_output.coefs\n    eta = head_output.eta\n    gamma = head_output.gamma\n\n    theta = torch.exp(eta) + 1.02\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()\n\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n    pi_gamma = torch.exp(gamma)\n\n    p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)\n\n    p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)\n\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n\n    P = torch.hstack((p_win, p_tie, p_tie_bb))\n\n    tie_ind = labels[:, -1].unsqueeze(-1)\n\n    p = P.gather(dim=-1, index=tie_ind).contiguous()\n\n    p = torch.clamp(p, min=1e-2)\n\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n\n    # print(\"loss: \", loss.item())\n\n    return loss\n\n\n@register_loss(\"bag\")\n@register_loss(\"grk\")\ndef GRK_loss(\n    head_output: HeadOutputs, labels: torch.Tensor, weights: torch.Tensor = None, **kwargs\n):\n    # labels are (winner_idx, loser_idx, tie_indicator (0 for no tie, 1 for tie, 2 for tie both bad))\n\n    coefs = head_output.coefs.float()\n    eta = head_output.eta.float()\n\n    theta = torch.exp(eta) + 1.000001\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()\n\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n    pi_gamma = 1.0\n\n    p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)\n\n    p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)\n\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n\n    assert p_win.shape == p_lose.shape == p_tie_bb.shape == p_tie.shape\n    P = torch.hstack((p_win, p_tie, p_tie_bb))\n\n    tie_ind = labels[:, -1].unsqueeze(-1)\n\n    p = P.gather(dim=-1, index=tie_ind).contiguous()\n\n    p = torch.clamp(p, min=1e-6)\n\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n\n    # print(\"loss: \", loss.item())\n\n    return loss\n\n\n
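# \"grk\"/\"bag\": the grounded variant of \"ba\", with pi_gamma pinned to 1.0\n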
@register_head(\"bt\")\nclass BTHead(nn.Module):\n    def __init__(\n        self, input_dim, output_dim, linear_head_downsize_factor=None, **kwargs\n    ) -> None:\n        super().__init__()\n\n        if linear_head_downsize_factor:\n            inner_dim = int(output_dim // linear_head_downsize_factor)\n            self.head = nn.Sequential(\n                nn.Linear(in_features=input_dim, out_features=inner_dim, bias=True),\n                nn.Linear(in_features=inner_dim, out_features=output_dim, bias=True),\n            )\n        else:\n            self.head = nn.Linear(\n                in_features=input_dim, out_features=output_dim, bias=True\n            )\n\n    def forward(self, last_hidden_dim: torch.Tensor):\n        coefs = self.head(last_hidden_dim)\n        return HeadOutputs(coefs=coefs)\n\n\n@register_head(\"rk\")\nclass RKHead(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        output_dim,\n        eta_dim=1,\n        linear_head_downsize_factor=None,\n        eta_downsize=False,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n        # If linear_head_downsize_factor and eta_downsize are both set, eta is computed from the downsized dim, not the hidden dim.\n        if linear_head_downsize_factor:\n            inner_dim = output_dim // linear_head_downsize_factor\n            share_layer = nn.Linear(\n                in_features=input_dim, out_features=inner_dim, bias=True\n            )\n            self.head = nn.Sequential(\n                share_layer,\n                nn.Linear(in_features=inner_dim, out_features=output_dim, bias=True),\n            )\n            if eta_downsize:\n                self.eta_head = nn.Sequential(\n                    share_layer,\n                    nn.Linear(in_features=inner_dim, out_features=eta_dim, bias=True),\n                )\n            else:\n                # forward applies eta_head to the hidden state, so it must take input_dim features\n                self.eta_head = nn.Linear(\n                    in_features=input_dim, out_features=eta_dim, bias=True\n                )\n        else:\n            self.head = nn.Linear(\n                in_features=input_dim, out_features=output_dim, bias=True\n            )\n            self.eta_head = nn.Linear(\n                in_features=input_dim, out_features=eta_dim, bias=True\n            )\n\n    def forward(self, last_hidden_dim: torch.Tensor):\n        coefs = self.head(last_hidden_dim)\n        eta = self.eta_head(last_hidden_dim)\n\n        return HeadOutputs(coefs=coefs, eta=eta)\n\n\n@register_head(\"ba\")\nclass BAHead(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        output_dim,\n        linear_head_downsize_factor=None,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n\n        if linear_head_downsize_factor:\n            raise NotImplementedError(\n                \"linear_head_downsize_factor is not supported for BAHead.\"\n            )\n\n        self.head = nn.Linear(in_features=input_dim, out_features=output_dim, bias=True)\n        self.eta_head = nn.Linear(in_features=input_dim, out_features=1, bias=True)\n        self.gamma_head = nn.Linear(in_features=input_dim, out_features=1, bias=True)\n\n    def forward(self, last_hidden_dim: torch.Tensor):\n\n        coefs = self.head(last_hidden_dim)\n        eta = self.eta_head(last_hidden_dim)\n        gamma = self.gamma_head(last_hidden_dim)\n\n        return HeadOutputs(coefs=coefs, eta=eta, gamma=gamma)\n\n\n@register_init(\"reset_params\")\ndef reset_params_init(module):\n    return module.reset_parameters()\n\n\n@register_init(\"he_unif\")\ndef he_unif_init(module):\n    return nn.init.kaiming_uniform_(module.weight, nonlinearity=\"sigmoid\")\n\n\n@register_init(\"xavier_unif\")\ndef xavier_unif_init(module):\n    return nn.init.xavier_uniform_(module.weight)\n\n\n# NOTE: despite the name, this is Kaiming normal init\n@register_init(\"tiny_normal\")\ndef tiny_normal_init(module):\n    return nn.init.kaiming_normal_(module.weight)\n\n\ndef get_p2l_model(\n    model_type: str, loss_type: str, head_type: str, init_type: str = \"reset_params\"\n) -> PreTrainedModel:\n    pretrained_model_cls, model_cls = registered_transformers[model_type]\n\n    criterion = registered_losses[loss_type]\n\n    head_layer = registered_heads[head_type]\n\n    init_func = registered_inits[init_type]\n\n    class CustomPretrainedModel(pretrained_model_cls):\n        \"\"\"Defines the appropriate pretrained class for the given model name.  
This is done so that the value head init scheme is correct.\"\"\"\n\n        def _init_weights(self, module):\n            std = self.config.initializer_range\n            if isinstance(module, nn.Linear):\n                init_func(module)  # was reset params\n                if module.bias is not None:\n                    module.bias.data.zero_()\n            elif isinstance(module, nn.Embedding):\n                module.weight.data.normal_(mean=0.0, std=std)\n                if module.padding_idx is not None:\n                    module.weight.data[module.padding_idx].zero_()\n\n    class P2LModel(CustomPretrainedModel):\n        def __init__(\n            self,\n            config,\n            CLS_id,\n            num_models,\n            linear_head_downsize_factor=None,\n            head_kwargs={},\n            **kwargs,\n        ):\n            super().__init__(config)\n\n            self.num_models = num_models\n            self.cls_token_id = CLS_id\n\n            self.model = model_cls(config)\n\n            self.head = head_layer(\n                input_dim=config.hidden_size,\n                output_dim=self.num_models,\n                linear_head_downsize_factor=linear_head_downsize_factor,\n                **head_kwargs,\n            )\n\n            self.post_init()\n\n        def freeze_transformer(self):\n            for param in self.model.parameters():\n                param.requires_grad = False\n\n        def get_input_embeddings(self):\n            return self.model.embed_tokens\n\n        def set_input_embeddings(self, value):\n            self.model.embed_tokens = value\n\n        def forward(self, input_ids, attention_mask, labels=None, weights=None):\n            batch_size = input_ids.shape[0]\n\n            hidden_outputs = self.model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                output_hidden_states=False,\n            ).last_hidden_state  # (bs, num_token, embed_dim)\n\n            cls_mask = input_ids == self.cls_token_id\n\n            # double check this is getting the current CLS token\n            cls_hidden_dim = hidden_outputs[cls_mask]\n\n            assert (\n                cls_hidden_dim.shape[0] == batch_size\n            ), f\"input ids {input_ids.shape}, cls_mask {cls_mask.shape}, cls_logit {cls_hidden_dim.shape}\"\n\n            head_output = self.head(cls_hidden_dim)\n\n            if labels is not None:\n                loss = criterion(head_output, labels, weights=weights)\n\n                outputs = P2LOutputs(\n                    coefs=head_output.coefs,\n                    last_hidden_state=cls_hidden_dim,\n                    eta=head_output.eta,\n                    gamma=head_output.gamma,\n                    loss=loss,\n                )\n\n            else:\n                outputs = P2LOutputs(\n                    coefs=head_output.coefs,\n                    last_hidden_state=cls_hidden_dim,\n                    eta=head_output.eta,\n                    gamma=head_output.gamma,\n                )\n\n            return outputs\n\n    return P2LModel\n\n\ndef get_tokenizer(\n    tokenizer_name,\n    chat_template,\n    pad_token_if_none=\"<|pad|>\",\n    cls_token_if_none=\"<|cls|>\",\n):\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n    tokenizer.truncation_side = \"left\"\n    tokenizer.padding_side = \"right\"\n\n    if chat_template:\n        tokenizer.chat_template = chat_template\n\n    if \"pad_token\" not in tokenizer.special_tokens_map:\n      
  tokenizer.add_special_tokens({\"pad_token\": pad_token_if_none})\n    if \"cls_token\" not in tokenizer.special_tokens_map:\n        tokenizer.add_special_tokens({\"cls_token\": cls_token_if_none})\n\n    return tokenizer\n\n\n@register_aggr_model(\"bt\")\n@register_aggr_model(\"bt-tie\")\nclass BTAggrModel(nn.Module):\n    def __init__(self, num_models, batch_size=1):\n        super().__init__()\n        self.coefs = nn.Parameter(\n            nn.init.constant_(torch.empty(batch_size, num_models), 0.5)\n        )\n        self.eta = None\n\n    def forward(self):\n        return self.coefs, self.eta\n\n\n@register_aggr_model(\"rk\")\n@register_aggr_model(\"rk-reparam\")\n@register_aggr_model(\"bag\")\n@register_aggr_model(\"grk\")\nclass RKAggrModel(nn.Module):\n    def __init__(self, num_models, batch_size=1):\n        super().__init__()\n        self.coefs = nn.Parameter(\n            nn.init.constant_(torch.empty(batch_size, num_models), 0.5)\n        )\n        self.eta = nn.Parameter(nn.init.constant_(torch.empty(batch_size, 1), 0.1))\n\n    def forward(self):\n        return self.coefs, self.eta\n\n\n@register_pairwise_loss(\"bt\")\n@register_pairwise_loss(\"bt-tie\")\ndef pairwise_batch_BT_loss(\n    real_output: HeadOutputs, aggregated_output: HeadOutputs, true_probs: torch.tensor\n):\n    real_betas = real_output.coefs\n    aggregated_betas = aggregated_output.coefs\n\n    num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]\n    beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]\n\n    pred_probs = torch.sigmoid(beta_i_agg - beta_j_agg)\n\n    pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1)\n\n    eps = 1e-9\n    neg_log_prob = -(\n        true_probs * torch.log(pred_probs_expanded + eps)\n        + (1 - true_probs) * torch.log(1 - pred_probs_expanded + eps)\n    )\n\n    batch_losses = neg_log_prob.mean(dim=(1, 2))\n    loss = batch_losses.mean()\n\n    return loss\n\n\n# batched loss\n@register_pairwise_loss(\"rk\")\ndef pairwise_batch_RK_loss(\n    real_output: HeadOutputs, aggregated_output: HeadOutputs, true_probs: torch.tensor\n):\n    real_betas = real_output.coefs\n    num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]\n\n    aggregated_betas = aggregated_output.coefs\n    BETA = 0.1\n    aggregated_eta = torch.clamp(\n        torch.nn.functional.softplus(aggregated_output.eta - 22.5, BETA).squeeze(-1),\n        min=0.02,\n    )\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]\n    beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]\n\n    aggregated_eta = aggregated_eta.unsqueeze(-1)\n    pred_probs_win = torch.sigmoid(beta_i_agg - beta_j_agg - aggregated_eta)\n    pred_probs_loss = torch.sigmoid(beta_j_agg - beta_i_agg - aggregated_eta)\n    pred_probs_tie = 1 - pred_probs_win - pred_probs_loss\n\n    pred_probs = torch.stack((pred_probs_win, pred_probs_loss, pred_probs_tie), dim=-1)\n\n    pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)\n\n    eps = 1e-9\n    neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)\n\n    batch_losses = neg_log_prob.mean(dim=(1, 2))\n    loss = 
batch_losses.mean()\n\n    return loss\n\n\n# batched\n@register_pairwise_loss(\"rk-reparam\")\ndef pairwise_batch_RK_reparam_loss(\n    real_output: HeadOutputs,\n    aggregated_output: HeadOutputs,\n    true_probs: torch.tensor,\n    **kwargs,\n):\n    real_betas = real_output.coefs\n    num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]\n\n    aggregated_betas = aggregated_output.coefs\n    aggregated_theta = torch.exp(aggregated_output.eta) + 1.000001\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]\n    beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]\n\n    pi_win = torch.exp(beta_i_agg)\n    pi_lose = torch.exp(beta_j_agg)\n\n    p_win = pi_win / (pi_win + aggregated_theta * pi_lose + 1.0)\n    p_lose = pi_lose / (pi_lose + aggregated_theta * pi_win + 1.0)\n    p_tie = 1.0 - p_win - p_lose\n\n    pred_probs = torch.stack((p_win, p_lose, p_tie), dim=-1)\n    pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)\n\n    eps = 1e-9\n    neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)\n    batch_losses = neg_log_prob.mean(dim=(1, 2))\n    loss = batch_losses.mean()\n\n    return loss\n\n\ndef get_bag_probs(beta_win, beta_lose, gamma, theta):\n    # NOTE: the gamma argument is currently unused; the tie-both-bad strength\n    # pi_gamma is fixed at 1.0 (every caller in this file passes gamma=1.0).\n    pi_win = torch.exp(beta_win)\n    pi_lose = torch.exp(beta_lose)\n    pi_gamma = 1.0\n\n    p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)\n\n    p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)\n\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n\n    return torch.stack((p_win, p_lose, p_tie, p_tie_bb), dim=-1)\n\n\n# batched\n@register_pairwise_loss(\"bag\")\n@register_pairwise_loss(\"grk\")\ndef pairwise_batch_bag_loss(\n    real_output: HeadOutputs,\n    aggregated_output: HeadOutputs,\n    true_probs: torch.tensor,\n    **kwargs,\n):\n    real_betas = real_output.coefs\n    num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]\n\n    aggregated_betas = aggregated_output.coefs\n    aggregated_theta = torch.exp(aggregated_output.eta) + 1.000001\n\n    pair_indices = torch.tensor(\n        [(i, j) for i in range(num_models) for j in range(i + 1, num_models)],\n        dtype=torch.long,\n    )\n\n    beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]\n    beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]\n\n    pred_probs = get_bag_probs(beta_i_agg, beta_j_agg, 1.0, aggregated_theta)\n\n    pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)\n\n    eps = 1e-9\n    neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)\n    batch_losses = neg_log_prob.mean(dim=(1, 2))\n    loss = batch_losses.mean()\n\n    return loss\n\n\n@register_loss(\"tie-rk\")\ndef RK_Tie_Loss(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n    eta = torch.clamp(\n        torch.nn.functional.softplus(head_output.eta - 22.5, BETA).squeeze(-1), min=0.02\n    )\n    model_idx = labels[:, :2]\n    paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()\n\n    paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]\n\n    p_w = torch.sigmoid(paired_delta_logit - eta)\n    p_l = torch.sigmoid(-1 * paired_delta_logit - eta)\n    p_t = 1 - p_w - p_l\n\n    p_not_t = p_w + p_l\n\n    A = 
torch.stack((p_not_t, p_t))\n\n    tie_ind = labels[:, -1].unsqueeze(0)\n    p = A.take_along_dim(dim=0, indices=tie_ind)\n\n    p = torch.clamp(p, min=1e-3)\n\n    loss = -torch.log(p)\n    if weights is not None:\n        loss = loss * weights\n    loss = loss.mean()\n\n    return loss\n\n\n@register_loss(\"tie-bag\")\n@register_loss(\"tie-grk\")\ndef bag_tie_loss(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n    eta = head_output.eta\n\n    theta = torch.exp(eta) + 1.000001\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()\n\n    p_win, p_lose, p_tie, p_tie_bb = torch.unbind(\n        get_bag_probs(beta_win, beta_lose, 1.0, theta), dim=-1\n    )\n\n    P = torch.hstack((p_win + p_lose, p_tie + p_tie_bb))\n\n    tie_ind = labels[:, -1].unsqueeze(-1)\n    tie_ind = torch.where(tie_ind == 0, 0, 1)  # segment into ties and not ties\n\n    p = P.gather(dim=-1, index=tie_ind).contiguous()\n\n    p = torch.clamp(p, min=1e-6)\n\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n    return loss\n\n\n@register_loss(\"tie-bb-bag\")\n@register_loss(\"tie-bb-grk\")\ndef bag_tie_bb_loss(\n    head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs\n):\n    coefs = head_output.coefs\n    eta = head_output.eta\n\n    theta = torch.exp(eta) + 1.000001\n\n    winner_idx = labels[:, 0:1]\n    loser_idx = labels[:, 1:2]\n\n    beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()\n    beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()\n\n    p_win, p_lose, p_tie, p_tie_bb = torch.unbind(\n        get_bag_probs(beta_win, beta_lose, 1.0, theta), dim=-1\n    )\n\n    P = torch.hstack((p_win + p_lose + p_tie, p_tie_bb))\n\n    tie_ind = labels[:, -1].unsqueeze(-1)\n    tie_ind = torch.where(tie_ind == 2, 1, 0)  # index should be 1 if tie-bb\n\n    p = P.gather(dim=-1, index=tie_ind).contiguous()\n\n    p = torch.clamp(p, min=1e-6)\n\n    loss = -torch.log(p)\n\n    if weights is not None:\n        loss = loss * weights\n\n    loss = loss.mean()\n    return loss\n"
  },
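  {
    "path": "examples/grk_probabilities_demo.py",
    "content": "# examples/grk_probabilities_demo.py\n# Illustrative sketch, NOT part of the original repo: this hypothetical file\n# shows how the grounded Rao-Kupper (\"grk\"/\"bag\") model used by GRK_loss and\n# get_bag_probs in p2l/model.py turns two Bradley-Terry coefficients and a tie\n# parameter eta into the four outcome probabilities (win, lose, tie,\n# tie-both-bad). All values below are made up.\nimport torch\n\n\ndef grk_probs(beta_win, beta_lose, eta):\n    # Mirrors get_bag_probs with the tie-both-bad strength pi_gamma fixed at 1.\n    theta = torch.exp(eta) + 1.000001  # theta > 1 keeps the tie mass nonnegative\n    pi_win, pi_lose, pi_gamma = torch.exp(beta_win), torch.exp(beta_lose), 1.0\n    p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)\n    p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)\n    p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)\n    p_tie = 1.0 - p_win - p_lose - p_tie_bb\n    return p_win, p_lose, p_tie, p_tie_bb\n\n\nif __name__ == \"__main__\":\n    probs = grk_probs(torch.tensor(1.2), torch.tensor(0.4), torch.tensor(-1.0))\n    print([round(float(p), 4) for p in probs])  # the four entries sum to 1\n"
  },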
  {
    "path": "p2l/train.py",
    "content": "import argparse\nimport os\nimport yaml\nimport json\nimport random\nfrom transformers import Trainer, TrainingArguments, set_seed\nfrom p2l.dataset import DataCollator, get_model_list, get_dataset, translate_val_data\nfrom p2l.model import get_p2l_model, get_tokenizer\nfrom torch.utils.data import Sampler\nfrom typing import Optional\nfrom huggingface_hub import HfApi\n\n# Want control over data ordering, use no shuffle trainer.\nclass NoShuffleTrainer(Trainer):\n    def _get_train_sampler(self) -> Optional[Sampler]:\n        return None\n\n\ndef train_model(args):\n\n    with open(args.config, \"r\") as file:\n        config = yaml.safe_load(file)\n\n    learning_rate = config[\"learning_rate\"]\n    # Microbatch size\n    batch_size = config[\"batch_size\"]\n    # HF data path\n    train_data_path = config[\"train_data_path\"]\n    val_data_path = config[\"val_data_path\"]\n    output_dir = config[\"output_dir\"]\n    pretrain_model_name = config[\"pretrain_model_name\"]\n    # Prompts will be truncted to this length\n    max_length = config[\"max_length\"]\n    gradient_accumulation_steps = config[\"gradient_accumulation_steps\"]\n    # Deepspeed config choices can be found in the deepspeed directory\n    deepspeed_config_path = config[\"deepspeed_config_path\"]\n    # Type of transformer, see model.py for options.\n    model_type = config[\"model_type\"]\n    # Loss type (e.g, bt, rk), see model.py for options.\n    loss_type = config[\"loss_type\"]\n    # The linear head type, see model.py for options.\n    head_type = config[\"head_type\"]\n\n    # Epsilon value for Adam\n    adam_epsilon = config[\"adam_epsilon\"]\n\n    # Optional\n    epochs = config.get(\"num_train_epochs\", 1)\n    lr_scheduler = config.get(\"lr_schedule\", \"constant\")\n    chat_template = config.get(\"chat_template\", None)\n    # Downsize the rank of the classification head.\n    linear_head_downsize_factor = config.get(\"linear_head_downsize_factor\", None)\n    # Whether to weight the loss. 
If this is true, it expects that the dataset has a \"weight\" column.\n    weighted_loss = config.get(\"weighted_loss\", False)\n    # kwargs for the head init.\n    head_config = config.get(\"head_config\", {})\n    # If the tokenizer/model does not already have a cls token, this will be used.\n    cls_token_if_none = config.get(\"cls_token_if_none\", \"<|cls|>\")\n    # If the tokenizer/model does not already have a pad token, this will be used.\n    pad_token_if_none = config.get(\"pad_token_if_none\", \"<|pad|>\")\n    # If using weighted loss, scalar reweight factor\n    reweight_scale = config.get(\"reweight_scale\", None)\n    proj_name = config.get(\"proj_name\", None)\n    init_type = config.get(\"init_type\", \"reset_params\")\n    train_head_only = config.get(\"train_head_only\", False)\n    load_train_data_from_disk = config.get(\"load_train_data_from_disk\", False)\n    load_val_data_from_disk = config.get(\"load_val_data_from_disk\", False)\n\n    LOCAL_RANK = int(os.environ.get(\"LOCAL_RANK\", -1))\n\n    os.makedirs(output_dir, exist_ok=True)\n\n    # define project name\n    if not proj_name:\n        proj_name = f\"{pretrain_model_name.split('/')[1]}_lr{learning_rate}_bs{batch_size}_ep{epochs}\"\n\n    print(f\"project name: {proj_name}\")\n\n    output_path = os.path.join(output_dir, proj_name)\n\n    if args.checkpoint:\n        resume_from_checkpoint = args.checkpoint\n        print(\"resuming from checkpoint\")\n    else:\n        resume_from_checkpoint = False\n\n    if not resume_from_checkpoint:\n        version = 1\n        while os.path.exists(output_path):\n            output_path = output_path.replace(f\"_{version - 1}\", \"\")\n            output_path = output_path + f\"_{version}\"\n            version += 1\n\n    with open(deepspeed_config_path) as fin:\n        deepspeed_config = json.load(fin)\n\n    random.seed(42)\n    set_seed(42)\n\n    training_args = TrainingArguments(\n        output_dir=output_path,\n        report_to=\"wandb\",\n        run_name=proj_name,\n        num_train_epochs=epochs,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        save_strategy=\"no\" if args.save_steps == -1 else \"steps\",\n        save_steps=None if args.save_steps == -1 else args.save_steps,\n        save_only_model=True,\n        eval_strategy=\"no\",\n        logging_strategy=\"steps\",\n        logging_steps=1,\n        ddp_timeout=9999999,\n        per_device_train_batch_size=batch_size,\n        per_device_eval_batch_size=batch_size,\n        eval_accumulation_steps=1,\n        eval_steps=args.eval_steps,\n        lr_scheduler_type=lr_scheduler,\n        logging_dir=\"./logs\",\n        fp16=False,\n        bf16=True,\n        learning_rate=learning_rate,\n        adam_epsilon=adam_epsilon,\n        load_best_model_at_end=False,\n        gradient_checkpointing=True,\n        do_train=True,\n        bf16_full_eval=True,\n        save_safetensors=True,\n        disable_tqdm=False,\n        remove_unused_columns=False,\n        deepspeed=deepspeed_config,\n        seed=42,\n        data_seed=42,\n        local_rank=LOCAL_RANK,\n    )\n\n    tokenizer = get_tokenizer(\n        pretrain_model_name,\n        chat_template,\n        pad_token_if_none=pad_token_if_none,\n        cls_token_if_none=cls_token_if_none,\n    )\n\n    data_collator = DataCollator(\n        tokenizer, max_length, weight=weighted_loss, reweight_scale=reweight_scale\n    )\n\n    train_data = get_dataset(\n        train_data_path, \"train\", 
from_disk=load_train_data_from_disk\n    )\n\n    if not args.no_eval:\n        val_data = get_dataset(val_data_path, \"train\", from_disk=load_val_data_from_disk)\n\n    # with training_args.main_process_first():\n\n    model_list = get_model_list(train_data)\n\n    if not args.no_eval:\n        val_model_list = get_model_list(val_data)\n\n        if model_list != val_model_list:\n            print(\"WARNING: Val model list is different, translating...\")\n            val_data = translate_val_data(val_data, model_list, val_model_list)\n\n    if LOCAL_RANK <= 0:\n        # Document the configuration in the output path.\n        os.makedirs(output_path, exist_ok=False)\n\n        with open(os.path.join(output_path, \"training_config.json\"), \"w\") as fout:\n            json.dump(config, fout, indent=1)\n\n        # Save the model list so we know which models this model was trained on. The model list is ALWAYS sorted alphabetically.\n        with open(os.path.join(output_path, \"model_list.json\"), \"w\") as fout:\n            json.dump(model_list, fout, indent=1)\n\n    # Get the model class\n    model_cls = get_p2l_model(\n        model_type=model_type,\n        loss_type=loss_type,\n        head_type=head_type,\n        init_type=init_type,\n    )\n\n    if resume_from_checkpoint:\n        print(f\"Loading model from checkpoint: {resume_from_checkpoint}\")\n        model = model_cls.from_pretrained(\n            resume_from_checkpoint,\n            CLS_id=tokenizer.cls_token_id,\n            num_models=len(model_list),\n            linear_head_downsize_factor=linear_head_downsize_factor,\n        )\n    else:\n        model = model_cls.from_pretrained(\n            pretrain_model_name,\n            CLS_id=tokenizer.cls_token_id,\n            num_models=len(model_list),\n            linear_head_downsize_factor=linear_head_downsize_factor,\n        )\n\n    if model.config.vocab_size < len(tokenizer):\n        print(\"WARNING: Resizing Token Embedding\")\n        model.resize_token_embeddings(len(tokenizer))\n\n    if train_head_only:\n        print(\"Freezing transformer, only training head.\")\n        model.freeze_transformer()\n\n    trainer = NoShuffleTrainer(\n        model=model,\n        args=training_args,\n        train_dataset=train_data.with_format(\"torch\"),\n        # eval_dataset=val_data.with_format(\"torch\"),\n        data_collator=data_collator,\n    )\n\n    print(\"begin training\")\n    trainer.train(resume_from_checkpoint=resume_from_checkpoint)\n\n    trainer.save_model(output_path)\n    tokenizer.save_pretrained(output_path)\n    print(\"saved model and tokenizer\")\n\n    if not args.no_eval:\n        print(\"starting eval\")\n        eval_results = trainer.predict(val_data.with_format(\"torch\"))\n        eval_metrics = eval_results.metrics\n        eval_predictions = eval_results.predictions\n        print(f\"Evaluation Results: {eval_metrics}\")\n\n        val_set = val_data.add_column(\"betas\", list(eval_predictions[0]))\n\n        if LOCAL_RANK <= 0:\n            with open(os.path.join(output_path, \"eval_results.json\"), \"w\") as fout:\n                json.dump(eval_metrics, fout, indent=1)\n\n            val_dir = os.path.join(output_path, \"eval_output.jsonl\")\n            val_set.to_json(val_dir)\n            print(f\"saved merged eval results\")\n\n    if LOCAL_RANK <= 0:\n        if args.push_to_hf:\n            api = HfApi()\n            repo_id = config.get(\"repo_id\", f\"p2el/{proj_name}\")\n            assert not api.repo_exists(\n             
   repo_id=repo_id, repo_type=\"model\"\n            ), \"repo already exists\"\n\n            api.create_repo(repo_id=repo_id, private=True, repo_type=\"model\")\n            api.upload_folder(\n                folder_path=output_path,\n                repo_id=repo_id,\n                repo_type=\"model\",\n            )\n\n            print(\"pushed to hub\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Argument Parser\")\n    parser.add_argument(\n        \"--config\", type=str, help=\"path to config file for model training\"\n    )\n    parser.add_argument(\n        \"--checkpoint\",\n        type=str,\n        help=\"path to checkpoint directory to resume training from\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--push-to-hf\",\n        action=\"store_true\",\n        help=\"If flagged, push the trained model directly to Hugging Face.\",\n    )\n    parser.add_argument(\n        \"--eval-steps\", type=int, default=60, help=\"Number of steps between evaluation.\"\n    )\n    parser.add_argument(\n        \"--local_rank\", type=int, default=-1, help=\"Local rank passed by DeepSpeed\"\n    )\n    parser.add_argument(\n        \"--no-eval\",\n        action=\"store_true\",\n        help=\"If flagged, eval will not run at the end of the training loop.\",\n    )\n    parser.add_argument(\"--save-steps\", type=int, default=-1)\n\n    args = parser.parse_args()\n\n    train_model(args)\n"
  },
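  {
    "path": "configs/example_train_config.yaml",
    "content": "# Illustrative sketch, NOT part of the original repo: a minimal training\n# config with the keys that p2l/train.py reads. The path of this file and all\n# values below are placeholders; point the data/model/deepspeed fields at your\n# own resources.\nlearning_rate: 1.0e-5\nbatch_size: 4  # microbatch size per device\ntrain_data_path: <hf-dataset-or-local-path>\nval_data_path: <hf-dataset-or-local-path>\noutput_dir: ./outputs\npretrain_model_name: Qwen/Qwen2-7B-Instruct\nmax_length: 4096  # prompts are truncated (from the left) to this length\ngradient_accumulation_steps: 8\ndeepspeed_config_path: deepspeed/<config>.json\nmodel_type: qwen2\nloss_type: grk\nhead_type: rk\nadam_epsilon: 1.0e-8\n# Optional keys (see train.py for defaults): num_train_epochs, lr_schedule,\n# chat_template, linear_head_downsize_factor, weighted_loss, head_config,\n# cls_token_if_none, pad_token_if_none, reweight_scale, proj_name, init_type,\n# train_head_only, load_train_data_from_disk, load_val_data_from_disk.\n"
  },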
  {
    "path": "probe_barrier.py",
    "content": "# probe_barrier.py\nimport os, sys, time, datetime, argparse\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\ndef log(msg: str, rank: int):\n    \"\"\"timestamped, unbuffered print\"\"\"\n    print(f\"[{rank}|{time.time():.3f}] {msg}\", flush=True)\n\ndef worker(rank: int, world_size: int, backend: str):\n    # ─── mandatory NCCL housekeeping ────────────────────────────\n    os.environ.setdefault(\"MASTER_ADDR\", \"127.0.0.1\")\n    os.environ.setdefault(\"MASTER_PORT\", \"29501\")\n    os.environ[\"RANK\"]       = str(rank)\n    os.environ[\"WORLD_SIZE\"] = str(world_size)\n\n    if backend == \"nccl\":\n        torch.cuda.set_device(rank)           # 1 GPU per rank\n    # ────────────────────────────────────────────────────────────\n\n    dist.init_process_group(\n        backend          = backend,\n        rank             = rank,\n        world_size       = world_size,\n        timeout          = datetime.timedelta(seconds=30)  # fail fast\n    )\n\n    log(\"reached barrier()\", rank)\n    dist.barrier()\n    log(\"*** passed  barrier()\", rank)\n\n    # Try another collective just to be sure\n    tensor = torch.tensor([rank], device=\"cuda\" if backend == \"nccl\" else \"cpu\")\n    dist.all_reduce(tensor)\n    log(f\"all_reduce ok, value={tensor.item()}\", rank)\n\n    dist.destroy_process_group()\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--nprocs\",  type=int,   default=2)\n    parser.add_argument(\"--backend\", choices=[\"gloo\", \"nccl\"], default=\"gloo\")\n    args = parser.parse_args()\n\n    mp.spawn(\n        worker,\n        args=(args.nprocs, args.backend),\n        nprocs=args.nprocs,\n        join=True\n    )\n\nif __name__ == \"__main__\":\n    # Completely unbuffered stdout/stderr\n    os.environ[\"PYTHONUNBUFFERED\"] = \"1\"\n    main()\n"
  },
  {
    "path": "route/chat.py",
    "content": "from typing import List, Dict, Iterator, Tuple\nimport openai.resources\nfrom abc import ABC, abstractmethod\nimport openai\nfrom openai import OpenAI\nimport anthropic\nfrom route.utils import get_registry_decorator\nimport time\nfrom route.datatypes import (\n    Roles,\n    ChatMessage,\n    ChatCompletionResponse,\n    Choice,\n    ChatMessageDelta,\n    ChoiceDelta,\n    ChatCompletionResponseChunk,\n    RouterOutput,\n    ModelConfig,\n)\nimport logging\nfrom openai.types.chat.chat_completion_chunk import ChatCompletionChunk\nfrom openai.types.chat.chat_completion import ChatCompletion\nfrom anthropic.lib.streaming import MessageStream\nfrom anthropic.types.message_start_event import MessageStartEvent\nfrom uuid import uuid4\n\n\nclass BaseChatHandler(ABC):\n\n    @staticmethod\n    @abstractmethod\n    def _create_client(model_config: ModelConfig):\n        pass\n\n    @staticmethod\n    @abstractmethod\n    def _handle_system_prompt(\n        messages: List[ChatMessage], model_config: ModelConfig\n    ) -> List[ChatMessage]:\n        pass\n\n    @staticmethod\n    @abstractmethod\n    def generate(\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> ChatCompletionResponse:\n        pass\n\n    @staticmethod\n    @abstractmethod\n    def generate_stream(\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> Iterator[ChatCompletionResponseChunk]:\n        pass\n\n\nCHAT_HANDLERS: Dict[str, BaseChatHandler] = {}\n\nregister = get_registry_decorator(CHAT_HANDLERS)\n\n\n@register(\"openai\")\nclass OpenAIChatHandler(BaseChatHandler):\n\n    @staticmethod\n    def _create_client(model_config: ModelConfig):\n\n        api_key = model_config.get_api_key()\n        base_url = model_config.get_base_url()\n\n        if api_key or base_url:\n\n            client = openai.OpenAI(\n                base_url=base_url,\n                api_key=api_key,\n            )\n\n        else:\n\n            client = openai.OpenAI()\n\n        return client\n\n    @staticmethod\n    def _handle_system_prompt(\n        messages: List[ChatMessage], model_config: ModelConfig\n    ) -> List[ChatMessage]:\n\n        system_prompt = model_config.get_system_prompt()\n\n        if system_prompt != None and messages[0].role != Roles.SYSTEM.value:\n\n            system_message = ChatMessage(\n                role=Roles.SYSTEM.value,\n                content=system_prompt,\n            )\n\n            messages = [system_message] + messages\n\n        return messages\n\n    @staticmethod\n    def _create_completion(\n        client: OpenAI,\n        model_config: ModelConfig,\n        messages: List[ChatMessage],\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n        stream=False,\n    ) -> ChatCompletion | Iterator[ChatCompletionChunk]:\n\n        completion = client.chat.completions.create(\n            model=model_config.get_name(),\n            messages=messages,\n            temperature=model_config.get_temp() if not temp else temp,\n            top_p=model_config.get_top_p() if not top_p else top_p,\n            max_tokens=(\n                model_config.get_max_tokens(default=openai.NOT_GIVEN)\n                if not max_tokens\n                else max_tokens\n            ),\n            stream=stream,\n  
      )\n\n        return completion\n\n    @classmethod\n    def generate(\n        cls,\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> ChatCompletionResponse:\n\n        model_config = router_output.chosen_model_config\n\n        client = cls._create_client(model_config=model_config)\n\n        messages = cls._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        completion: ChatCompletion = cls._create_completion(\n            client=client,\n            model_config=model_config,\n            messages=messages,\n            temp=temp,\n            top_p=top_p,\n            max_tokens=max_tokens,\n            stream=False,\n        )\n\n        logging.info(f\"{int(time.time())} Chosen Model Completion: {completion}\")\n\n        chat_completion = ChatCompletionResponse(\n            id=str(completion.id),\n            object=\"chat.completion\",\n            created=completion.created,\n            model=completion.model,\n            choices=[\n                Choice(\n                    index=choice.index,\n                    message=ChatMessage(\n                        role=choice.message.role,\n                        content=choice.message.content,\n                        model=router_output.chosen_model_name,\n                    ),\n                    finish_reason=choice.finish_reason,\n                )\n                for choice in completion.choices\n            ],\n            usage=completion.usage,\n            router_outputs=router_output.model_scores,\n        )\n\n        return chat_completion\n\n    def _skip(chunk: ChatCompletionChunk) -> bool:\n\n        try:\n\n            content = chunk.choices[0].delta.content\n\n            return content == \"\" or content == None\n        except Exception as e:\n            return True\n\n    @classmethod\n    def generate_stream(\n        cls,\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> Iterator[ChatCompletionResponseChunk]:\n\n        model_config = router_output.chosen_model_config\n\n        client = cls._create_client(model_config=model_config)\n\n        messages = cls._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        chunks: Iterator[ChatCompletionChunk] = cls._create_completion(\n            client=client,\n            model_config=model_config,\n            messages=messages,\n            temp=temp,\n            top_p=top_p,\n            max_tokens=max_tokens,\n            stream=True,\n        )\n\n        first_chunk = True\n\n        logging_content = \"\"\n\n        for chunk in chunks:\n\n            if cls._skip(chunk):\n                continue\n\n            logging_content += chunk.choices[0].delta.content\n\n            out_chunk = ChatCompletionResponseChunk(\n                id=str(chunk.id),\n                object=\"chat.completion.chunk\",\n                created=chunk.created,\n                model=chunk.model,\n                choices=[\n                    ChoiceDelta(\n                        index=choice.index,\n                        delta=ChatMessageDelta(\n                            role=choice.delta.role,\n                            content=choice.delta.content,\n                            
model=router_output.chosen_model_name,\n                        ),\n                    )\n                    for choice in chunk.choices\n                ],\n                usage=chunk.usage,\n                router_outputs=router_output.model_scores if first_chunk else None,\n            ).model_dump_json()\n\n            yield f\"data: {out_chunk}\\n\\n\"\n\n            first_chunk = False\n\n        logging.info(\n            f\"{int(time.time())} Chat Output (OpenAI Client): {logging_content}\"\n        )\n\n        yield \"data: [DONE]\\n\\n\"\n\n\n@register(\"openai-reasoning\")\nclass OpenaiReasoningChatHandler(OpenAIChatHandler):\n\n    @staticmethod\n    def _create_completion(\n        client: OpenAI,\n        model_config: ModelConfig,\n        messages: List[ChatMessage],\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n        stream=False,\n    ) -> ChatCompletion | Iterator[ChatCompletionChunk]:\n        \n        extra_field = model_config.get_extra_fields()\n\n        # No max tokens argument\n        completion = client.chat.completions.create(\n            model=model_config.get_name(), messages=messages, stream=stream, reasoning_effort=extra_field.get(\"reasoning_effort\", openai.NOT_GIVEN),\n        )\n\n        return completion\n\n\n@register(\"openai-o1\")\nclass OpenaiO1ChatHandler(OpenaiReasoningChatHandler):\n\n    @classmethod\n    def generate_stream(\n        cls,\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> Iterator[ChatCompletionResponseChunk]:\n\n        model_config = router_output.chosen_model_config\n\n        client = cls._create_client(model_config=model_config)\n\n        messages = cls._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        chunk: ChatCompletion = cls._create_completion(\n            client=client,\n            model_config=model_config,\n            messages=messages,\n            temp=temp,\n            top_p=top_p,\n            max_tokens=max_tokens,\n            stream=False,\n        )\n\n        out_chunk = ChatCompletionResponseChunk(\n            id=str(chunk.id),\n            object=\"chat.completion.chunk\",\n            created=chunk.created,\n            model=chunk.model,\n            choices=[\n                ChoiceDelta(\n                    index=choice.index,\n                    delta=ChatMessageDelta(\n                        role=choice.message.role,\n                        content=choice.message.content,\n                        model=router_output.chosen_model_name,\n                    ),\n                )\n                for choice in chunk.choices\n            ],\n            usage=chunk.usage,\n            router_outputs=router_output.model_scores,\n        ).model_dump_json()\n\n        yield f\"data: {out_chunk}\\n\\n\"\n\n        logging.info(\n            f\"{int(time.time())} Chat Output (OpenAI O1 Client): {chunk.choices[0].message.content}\"\n        )\n\n        yield \"data: [DONE]\\n\\n\"\n\n\n@register(\"anthropic\")\nclass AnthropicChatHandler(BaseChatHandler):\n\n    @staticmethod\n    def _create_client(model_config: ModelConfig):\n        client = anthropic.Anthropic(api_key=model_config.get_api_key())\n        return client\n\n    @staticmethod\n    @abstractmethod\n    def _handle_system_prompt(\n        messages: List[ChatMessage], model_config: ModelConfig\n    ) 
-> Tuple[List[ChatMessage], str | anthropic.NotGiven]:\n\n        system_message = model_config.get_system_prompt(default=anthropic.NOT_GIVEN)\n\n        if system_message == None:\n            system_message = anthropic.NOT_GIVEN\n\n        if messages[0].role == Roles.SYSTEM.value:\n\n            system_message = messages[0].content\n\n            messages = messages[1:]\n\n        return messages, system_message\n\n    @staticmethod\n    def generate(\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> ChatCompletionResponse:\n\n        model_config = router_output.chosen_model_config\n\n        client = AnthropicChatHandler._create_client(model_config=model_config)\n\n        messages, system_message = AnthropicChatHandler._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        completion = client.messages.create(\n            model=model_config.get_name(),\n            messages=messages,\n            stop_sequences=[anthropic.HUMAN_PROMPT],\n            temperature=model_config.get_temp() if not temp else temp,\n            top_p=model_config.get_top_p() if not top_p else top_p,\n            max_tokens=model_config.get_max_tokens() if not max_tokens else max_tokens,\n            system=system_message,\n        )\n\n        chat_completion = ChatCompletionResponse(\n            id=completion.id,\n            object=\"chat.completion\",\n            created=int(time.time()),\n            model=completion.model,\n            choices=[\n                Choice(\n                    index=i,\n                    message=ChatMessage(\n                        role=completion.role,\n                        content=content.text,\n                        model=router_output.chosen_model_name,\n                    ),\n                    finish_reason=completion.stop_reason,\n                )\n                for i, content in enumerate(completion.content)\n            ],\n            usage=completion.usage,\n            router_outputs=router_output.model_scores,\n        )\n\n        return chat_completion\n\n    @staticmethod\n    def generate_stream(\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> Iterator[ChatCompletionResponseChunk]:\n\n        model_config = router_output.chosen_model_config\n\n        client = AnthropicChatHandler._create_client(model_config=model_config)\n\n        messages, system_message = AnthropicChatHandler._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        with client.messages.stream(\n            model=model_config.get_name(),\n            messages=messages,\n            stop_sequences=[anthropic.HUMAN_PROMPT],\n            temperature=model_config.get_temp() if not temp else temp,\n            top_p=model_config.get_top_p() if not top_p else top_p,\n            max_tokens=model_config.get_max_tokens() if not max_tokens else max_tokens,\n            system=system_message,\n        ) as _stream:\n\n            stream: MessageStream = _stream\n\n            # This contains the metadata\n            message_start: MessageStartEvent = next(stream)\n\n            resp_id = message_start.message.id\n            model = message_start.message.model\n            role = message_start.message.role\n\n            # Ignore this 
useless chunk.\n            next(stream)\n\n            first_chunk = True\n\n            logging_content = \"\"\n\n            for text in stream.text_stream:\n\n                logging_content += text\n\n                out_chunk = ChatCompletionResponseChunk(\n                    id=resp_id,\n                    created=int(time.time()),\n                    model=model,\n                    object=\"chat.completion.chunk\",\n                    choices=[\n                        ChoiceDelta(\n                            delta=ChatMessageDelta(\n                                content=text,\n                                role=role,\n                                model=router_output.chosen_model_name,\n                            ),\n                            index=0,\n                        )\n                    ],\n                    router_outputs=router_output.model_scores if first_chunk else None,\n                ).model_dump_json()\n\n                yield f\"data: {out_chunk}\\n\\n\"\n\n                first_chunk = False\n\n            logging.info(\n                f\"{int(time.time())} Chat Output (Anthropic Client): {logging_content}\"\n            )\n\n            yield \"data: [DONE]\\n\\n\"\n\n\nimport google.generativeai as genai\nfrom google.generativeai.types.generation_types import GenerateContentResponse\n\n\n@register(\"gemini\")\nclass GeminiChatHandler(BaseChatHandler):\n\n    safety_settings = [\n        {\"category\": \"HARM_CATEGORY_HARASSMENT\", \"threshold\": \"BLOCK_NONE\"},\n        {\"category\": \"HARM_CATEGORY_HATE_SPEECH\", \"threshold\": \"BLOCK_NONE\"},\n        {\"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\", \"threshold\": \"BLOCK_NONE\"},\n        {\"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\", \"threshold\": \"BLOCK_NONE\"},\n    ]\n\n    @staticmethod\n    def _create_client(model_config: ModelConfig):\n\n        api_key = model_config.get_api_key()\n\n        if api_key:\n\n            genai.configure(api_key=api_key)\n\n    @staticmethod\n    def _handle_system_prompt(\n        messages: List[ChatMessage], model_config: ModelConfig\n    ) -> List[ChatMessage]:\n\n        system_prompt = model_config.get_system_prompt()\n\n        if system_prompt != None and messages[0].role != Roles.SYSTEM.value:\n\n            system_message = ChatMessage(\n                role=Roles.SYSTEM.value,\n                content=system_prompt,\n            )\n\n            messages = [system_message] + messages\n\n        return messages\n\n    @staticmethod\n    def _create_completion(\n        model_config: ModelConfig,\n        messages: List[ChatMessage],\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n        stream=False,\n    ) -> GenerateContentResponse | Iterator[GenerateContentResponse]:\n\n        generation_config = genai.GenerationConfig(\n            max_output_tokens=model_config.get_max_tokens(default=8192) if not max_tokens else max_tokens,\n            temperature=model_config.get_temp() if not temp else temp,\n            top_p=model_config.get_top_p() if not top_p else top_p,\n            top_k=model_config.get_top_k(),\n        )\n\n        history = []\n        system_prompt = None\n\n        for message in messages[:-1]:\n\n            if message.role == Roles.SYSTEM.value:\n                system_prompt = message.content\n\n            elif message.role == Roles.ASSISTANT.value:\n                history.append({\"role\": \"model\", \"parts\": message.content})\n\n            else:\n  
              history.append({\"role\": \"user\", \"parts\": message.content})\n\n        model = genai.GenerativeModel(\n            model_name=model_config.get_name(),\n            system_instruction=system_prompt,\n            generation_config=generation_config,\n            safety_settings=GeminiChatHandler.safety_settings,\n        )\n\n        chat_session = model.start_chat(history=history)\n\n        completion = chat_session.send_message(\n            content=messages[-1].content, stream=stream\n        )\n\n        return completion\n\n    @classmethod\n    def generate(\n        cls,\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> ChatCompletionResponse:\n\n        model_config = router_output.chosen_model_config\n\n        cls._create_client(model_config=model_config)\n\n        messages = cls._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        completion: GenerateContentResponse = cls._create_completion(\n            model_config=model_config,\n            messages=messages,\n            temp=temp,\n            top_p=top_p,\n            max_tokens=max_tokens,\n            stream=False,\n        )\n\n        logging.info(f\"{int(time.time())} Chosen Model Completion: {completion}\")\n\n        chat_completion = ChatCompletionResponse(\n            id=str(uuid4()),\n            object=\"chat.completion\",\n            created=int(time.time()),\n            model=model_config.get_name(),\n            choices=[\n                Choice(\n                    index=0,\n                    message=ChatMessage(\n                        role=Roles.ASSISTANT.value,\n                        content=completion.text,\n                        model=router_output.chosen_model_name,\n                    ),\n                    finish_reason=\"STOP\",\n                )\n            ],\n            router_outputs=router_output.model_scores,\n        )\n\n        return chat_completion\n\n    @classmethod\n    def generate_stream(\n        cls,\n        messages: List[ChatMessage],\n        router_output: RouterOutput,\n        temp: float | None,\n        top_p: float | None,\n        max_tokens: int | None,\n    ) -> Iterator[ChatCompletionResponseChunk]:\n\n        model_config = router_output.chosen_model_config\n\n        cls._create_client(model_config=model_config)\n\n        messages = cls._handle_system_prompt(\n            messages=messages, model_config=model_config\n        )\n\n        chunks: Iterator[GenerateContentResponse] = cls._create_completion(\n            model_config=model_config,\n            messages=messages,\n            temp=temp,\n            top_p=top_p,\n            max_tokens=max_tokens,\n            stream=True,\n        )\n\n        first_chunk = True\n\n        chat_id = str(uuid4())\n\n        logging_content = \"\"\n\n        for chunk in chunks:\n\n            logging_content += chunk.text\n\n            out_chunk = ChatCompletionResponseChunk(\n                id=chat_id,\n                object=\"chat.completion.chunk\",\n                created=int(time.time()),\n                model=model_config.get_name(),\n                choices=[\n                    ChoiceDelta(\n                        index=0,\n                        delta=ChatMessageDelta(\n                            role=Roles.ASSISTANT.value,\n                            content=chunk.text,\n                     
       model=router_output.chosen_model_name,\n                        ),\n                    )\n                ],\n                router_outputs=router_output.model_scores if first_chunk else None,\n            ).model_dump_json()\n\n            yield f\"data: {out_chunk}\\n\\n\"\n\n            first_chunk = False\n\n        logging.info(\n            f\"{int(time.time())} Chat Output (Gemini Client): {logging_content}\"\n        )\n\n        yield \"data: [DONE]\\n\\n\"\n"
  },
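  {
    "path": "examples/chat_handler_demo.py",
    "content": "# examples/chat_handler_demo.py\n# Illustrative sketch, NOT part of the original repo: this hypothetical file\n# shows how a handler from route/chat.py is looked up in the CHAT_HANDLERS\n# registry by the model config's \"type\" field and used to generate one\n# completion. The model name, cost, and API key below are placeholders.\nfrom route.chat import CHAT_HANDLERS\nfrom route.datatypes import ChatMessage, ModelConfig, RouterOutput\n\nconfig = ModelConfig(\n    {\n        \"name\": \"gpt-4o-mini-2024-07-18\",\n        \"type\": \"openai\",  # selects OpenAIChatHandler from the registry\n        \"cost\": 0.56,\n        \"temp\": 0.7,\n        \"top_p\": 1.0,\n        \"api_key\": \"<your-api-key>\",\n    }\n)\n\nrouter_output = RouterOutput(\n    chosen_model_name=\"gpt-4o-mini-2024-07-18\",\n    chosen_model_config=config,\n    model_scores={\"gpt-4o-mini-2024-07-18\": 1.0},\n)\n\nhandler = CHAT_HANDLERS[config.get_type()]\nresponse = handler.generate(\n    messages=[ChatMessage(role=\"user\", content=\"Hello!\")],\n    router_output=router_output,\n    temp=None,  # None falls back to the value in the model config\n    top_p=None,\n    max_tokens=None,\n)\nprint(response.choices[0].message.content)\n"
  },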
  {
    "path": "route/cost_optimizers.py",
    "content": "from abc import ABC, abstractmethod\nfrom route.utils import get_registry_decorator\nfrom typing import List, Dict\nimport numpy as np\nimport cvxpy as cp\nfrom scipy.special import expit\n\n\nclass UnfulfillableException(Exception):\n    pass\n\n\nclass BaseCostOptimizer(ABC):\n    def __init__(self):\n        super().__init__()\n\n    @staticmethod\n    @abstractmethod\n    def select_model(\n        cost: float,\n        model_list: List[str],\n        model_costs: np.ndarray[float],\n        model_scores: np.ndarray[float],\n        **kwargs,\n    ) -> str:\n        pass\n\n    @staticmethod\n    def select_max_score_model(\n        model_list: List[str], model_scores: np.ndarray[float]\n    ) -> str:\n\n        max_idx = np.argmax(model_scores)\n\n        return model_list[max_idx]\n\n\nCOST_OPTIMIZERS: Dict[str, BaseCostOptimizer] = {}\n\nregister = get_registry_decorator(COST_OPTIMIZERS)\n\n\n@register(\"strict\")\nclass StrictCostOptimizer(BaseCostOptimizer):\n\n    def __init__(self):\n        super().__init__()\n\n    @staticmethod\n    def select_model(\n        cost: float | None,\n        model_list: List[str],\n        model_costs: np.ndarray[float],\n        model_scores: np.ndarray[float],\n        **kwargs,\n    ) -> str:\n\n        if cost == None:\n            return StrictCostOptimizer.select_max_score_model(model_list, model_scores)\n\n        best_model: str | None = None\n        best_score = -float(\"inf\")\n\n        for model, model_cost, model_score in zip(\n            model_list, model_costs, model_scores\n        ):\n\n            if model_cost > cost:\n                continue\n\n            elif model_score > best_score:\n                best_model = model\n                best_score = model_score\n\n        if best_model is None:\n            raise UnfulfillableException(\n                f\"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}.\"\n            )\n\n        return best_model\n\n\n@register(\"simple-lp\")\nclass SimpleLPCostOptimizer(BaseCostOptimizer):\n\n    def __init__(self):\n        super().__init__()\n\n    @staticmethod\n    def select_model(\n        cost: float | None,\n        model_list: List[str],\n        model_costs: np.ndarray[float],\n        model_scores: np.ndarray[float],\n        **kwargs,\n    ) -> str:\n\n        if cost == None:\n            return StrictCostOptimizer.select_max_score_model(model_list, model_scores)\n\n        p = cp.Variable(len(model_costs))\n\n        prob = cp.Problem(\n            cp.Maximize(cp.sum(model_scores @ p)),\n            [model_costs.T @ p <= cost, cp.sum(p) == 1, p >= 0],\n        )\n\n        status = prob.solve()\n\n        if status < 0.0:\n            raise UnfulfillableException(\n                f\"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}.\"\n            )\n\n        ps = np.clip(p.value, a_min=0.0, a_max=1.0)\n        ps = ps / ps.sum()\n\n        return np.random.choice(model_list, p=ps)\n\n\n@register(\"optimal-lp\")\nclass OptimalLPCostOptimizer(BaseCostOptimizer):\n\n    def __init__(self):\n        super().__init__()\n\n    @staticmethod\n    def select_model(\n        cost: float | None,\n        model_list: List[str],\n        model_costs: np.ndarray[float],\n        model_scores: np.ndarray[float],\n        opponent_scores: np.ndarray[float] = None,\n        opponent_distribution: np.ndarray[float] = None,\n    ) -> str:\n\n        if cost == None:\n         
   return StrictCostOptimizer.select_max_score_model(model_list, model_scores)\n\n        W = OptimalLPCostOptimizer._construct_W(model_scores, opponent_scores)\n\n        Wq = W @ opponent_distribution\n\n        p = cp.Variable(len(model_costs))\n\n        prob = cp.Problem(\n            cp.Maximize(p @ Wq), [model_costs.T @ p <= cost, cp.sum(p) == 1, p >= 0]\n        )\n\n        status = prob.solve()\n\n        if status < 0.0:\n            raise UnfulfillableException(\n                f\"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}.\"\n            )\n\n        ps = np.clip(p.value, a_min=0.0, a_max=1.0)\n        ps = ps / ps.sum()\n\n        return np.random.choice(model_list, p=ps)\n\n    @staticmethod\n    def _construct_W(\n        router_model_scores: np.ndarray[float], opponent_model_scores: np.ndarray[float]\n    ) -> np.ndarray[float]:\n\n        num_rows = router_model_scores.shape[-1]\n        num_cols = opponent_model_scores.shape[-1]\n\n        chosen = np.tile(router_model_scores, (num_cols, 1)).T\n        rejected = np.tile(opponent_model_scores, (num_rows, 1))\n\n        assert chosen.shape == rejected.shape, (chosen.shape, rejected.shape)\n\n        diff_matrix = chosen - rejected\n\n        W = expit(diff_matrix)\n\n        return W\n"
  },
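  {
    "path": "examples/cost_optimizer_demo.py",
    "content": "# examples/cost_optimizer_demo.py\n# Illustrative sketch, NOT part of the original repo: this hypothetical file\n# exercises the \"strict\" optimizer from route/cost_optimizers.py, which picks\n# the highest-scoring model whose cost stays under the requested ceiling. The\n# model names, costs, and scores below are made up.\nimport numpy as np\n\nfrom route.cost_optimizers import COST_OPTIMIZERS, UnfulfillableException\n\noptimizer = COST_OPTIMIZERS[\"strict\"]\n\nmodel_list = [\"cheap-model\", \"mid-model\", \"expensive-model\"]\nmodel_costs = np.array([0.1, 1.0, 10.0])  # arbitrary cost units\nmodel_scores = np.array([-0.2, 0.4, 1.3])  # e.g. P2L Bradley-Terry coefficients\n\n# A ceiling of 2.0 rules out \"expensive-model\", so \"mid-model\" is chosen.\nprint(optimizer.select_model(cost=2.0, model_list=model_list, model_costs=model_costs, model_scores=model_scores))\n\n# With no ceiling, the best-scoring model overall is returned.\nprint(optimizer.select_model(cost=None, model_list=model_list, model_costs=model_costs, model_scores=model_scores))\n\n# An unattainable ceiling raises UnfulfillableException.\ntry:\n    optimizer.select_model(cost=0.01, model_list=model_list, model_costs=model_costs, model_scores=model_scores)\nexcept UnfulfillableException as err:\n    print(\"unfulfillable:\", err)\n"
  },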
  {
    "path": "route/datatypes.py",
    "content": "from typing import Dict, List, Any, Optional\nfrom dataclasses import dataclass\nfrom pydantic import BaseModel\nfrom enum import Enum\n\n\nclass ModelConfig:\n\n    def __init__(self, config: Dict[str, Any]):\n        self.config = config\n\n    def get_name(self) -> str:\n        return self.config[\"name\"]\n\n    def get_temp(self) -> float:\n        return self.config[\"temp\"]\n\n    def get_top_p(self) -> float:\n        return self.config[\"top_p\"]\n\n    def get_top_k(self, default=None) -> int:\n        return self.config.get(\"top_k\", default)\n\n    def get_system_prompt(self, default=None) -> str | None | Any:\n        return self.config.get(\"system_prompt\", default)\n\n    def get_api_key(self, default=None) -> str | None | Any:\n        return self.config.get(\"api_key\", default)\n\n    def get_base_url(self, default=None) -> str | None | Any:\n        return self.config.get(\"base_url\", default)\n\n    def get_type(self) -> str:\n        return self.config[\"type\"]\n\n    def get_cost(self) -> float:\n        return self.config[\"cost\"]\n\n    def get_max_tokens(self, default=None) -> int | None | Any:\n        return self.config.get(\"max_tokens\", default)\n    \n    def get_extra_fields(self) -> Dict:\n        return self.config.get(\"extra_fields\", {}) # Maybe should be None...\n\n    def __repr__(self):\n        return repr(\n            dict(\n                name=self.get_name(),\n                type=self.get_type(),\n                cost=self.get_cost(),\n            )\n        )\n\n\nclass ModelConfigContainer:\n    def __init__(self, model_config_dicts: Dict[str, Dict[str, Any]]):\n        self.model_configs: Dict[str, ModelConfig] = dict(\n            (name, ModelConfig(config)) for name, config in model_config_dicts.items()\n        )\n\n    def get_model_config(self, model_name: str) -> ModelConfig:\n        return self.model_configs[model_name]\n\n    def list_models(self) -> List[str]:\n        return list(self.model_configs.keys())\n\n    def list_costs(self) -> List[float]:\n\n        costs: List[float] = []\n\n        for model_name in self.list_models():\n            model_config = self.get_model_config(model_name)\n            costs.append(model_config.get_cost())\n\n        return costs\n\n    def __repr__(self):\n        return repr(self.model_configs)\n\n\nclass Roles(Enum):\n    USER = \"user\"\n    ASSISTANT = \"assistant\"\n    SYSTEM = \"system\"\n\n\nclass ChatMessage(BaseModel):\n    \"\"\"\n    Represents a single message in the conversation.\n    role: \"system\", \"user\", or \"assistant\"\n    content: the actual text\n    \"\"\"\n\n    role: str\n    content: str\n    model: Optional[str] = None\n\n\nclass ChatCompletionRequest(BaseModel):\n    \"\"\"\n    Request body for Chat Completion.\n    \"\"\"\n\n    model: str\n    messages: List[ChatMessage]\n    max_tokens: Optional[int] = None\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    n: Optional[int] = 1\n    stream: Optional[bool] = False\n    stop: Optional[List[str]] = None\n    cost: Optional[float] = None\n    direct_model: Optional[str] = None\n\n\nclass Choice(BaseModel):\n    \"\"\"\n    Represents a single choice in the final response (non-streaming mode).\n    \"\"\"\n\n    index: int\n    message: ChatMessage\n    finish_reason: str\n\n\nclass ChatCompletionResponse(BaseModel):\n    \"\"\"\n    Response model for non-streaming mode.\n    \"\"\"\n\n    id: str\n    object: str\n    created: int\n    model: str\n    
choices: List[Choice]\n    usage: Optional[BaseModel] = None\n    router_outputs: Optional[Dict[str, float]] = None\n\n\nclass ChatMessageDelta(BaseModel):\n    content: Optional[str] = None\n    role: Optional[str] = None\n    model: Optional[str] = None\n\n\nclass ChoiceDelta(BaseModel):\n    delta: ChatMessageDelta\n    finish_reason: Optional[str] = None\n    index: int\n\n\nclass ChatCompletionResponseChunk(BaseModel):\n    id: str\n    choices: List[ChoiceDelta]\n    created: int\n    model: str\n    object: str\n    usage: Optional[BaseModel] = None\n    router_outputs: Optional[Dict[str, float]] = None\n\n\n@dataclass\nclass RouterOutput:\n    chosen_model_name: str\n    chosen_model_config: ModelConfig\n    model_scores: Dict[str, float] | None\n"
  },
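  {
    "path": "examples/request_schema_demo.py",
    "content": "# examples/request_schema_demo.py\n# Illustrative sketch, NOT part of the original repo: this hypothetical file\n# builds the OpenAI-compatible request body defined in route/datatypes.py,\n# which extends the standard schema with \"cost\" (a cost constraint for the\n# cost optimizers) and \"direct_model\" (to bypass routing). Values below are\n# placeholders.\nfrom route.datatypes import ChatCompletionRequest, ChatMessage\n\nrequest = ChatCompletionRequest(\n    model=\"p2l-router\",  # placeholder model name\n    messages=[ChatMessage(role=\"user\", content=\"Explain beam search.\")],\n    stream=True,\n    cost=5.0,  # optional cost constraint\n    direct_model=None,  # optional: set to a model name to skip the router\n)\n\n# Pydantic validates and serializes it like a plain OpenAI payload.\nprint(request.model_dump_json(indent=2))\n"
  },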
  {
    "path": "route/example_config.yaml",
    "content": "model_configs:\n  athene-v2-chat:\n    api_key: <your-api-key>\n    base_url: http://38.142.9.21:10245/v1\n    cost: 0.8097264049\n    name: im-a-little-birdie\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  claude-3-5-haiku-20241022:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 2.1765185825\n    max_tokens: 8192\n    name: claude-3-5-haiku-20241022\n    temp: 0.7\n    top_p: 0.7\n    type: anthropic\n  claude-3-5-sonnet-20240620:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 9.4453041863\n    max_tokens: 8192\n    name: claude-3-5-sonnet-20240620\n    system_prompt: '<claude_info> The assistant is Claude, created by Anthropic. The\n      current date is 2025-01-06. Claude''s knowledge base was last updated on April\n      2024. It answers questions about events prior to and after April 2024 the way\n      a highly informed individual in April 2024 would if they were talking to someone\n      from the above date, and can let the human know this when relevant. Claude cannot\n      open URLs, links, or videos. If it seems like the user is expecting Claude to\n      do so, it clarifies the situation and asks the human to paste the relevant text\n      or image content directly into the conversation. If it is asked to assist with\n      tasks involving the expression of views held by a significant number of people,\n      Claude provides assistance with the task regardless of its own views. If asked\n      about controversial topics, it tries to provide careful thoughts and clear information.\n      It presents the requested information without explicitly saying that the topic\n      is sensitive, and without claiming to be presenting objective facts. When presented\n      with a math problem, logic problem, or other problem benefiting from systematic\n      thinking, Claude thinks through it step by step before giving its final answer.\n      If Claude cannot or will not perform a task, it tells the user this without\n      apologizing to them. It avoids starting its responses with \"I''m sorry\" or \"I\n      apologize\". If Claude is asked about a very obscure person, object, or topic,\n      i.e. if it is asked for the kind of information that is unlikely to be found\n      more than once or twice on the internet, Claude ends its response by reminding\n      the user that although it tries to be accurate, it may hallucinate in response\n      to questions like this. It uses the term ''hallucinate'' to describe this since\n      the user will understand what it means. If Claude mentions or cites particular\n      articles, papers, or books, it always lets the human know that it doesn''t have\n      access to search or a database and may hallucinate citations, so the human should\n      double check its citations. Claude is very smart and intellectually curious.\n      It enjoys hearing what humans think on an issue and engaging in discussion on\n      a wide variety of topics. If the user seems unhappy with Claude or Claude''s\n      behavior, Claude tells them that although it cannot retain or learn from the\n      current conversation, they can press the ''thumbs down'' button below Claude''s\n      response and provide feedback to Anthropic. If the user asks for a very long\n      task that cannot be completed in a single response, Claude offers to do the\n      task piecemeal and get feedback from the user as it completes each part of the\n      task. Claude uses markdown for code. 
Immediately after closing coding markdown,\n      Claude asks the user if they would like it to explain or break down the code.\n      It does not explain or break down the code unless the user explicitly requests\n      it. </claude_info>\n\n      <claude_3_family_info> This iteration of Claude is part of the Claude 3 model\n      family, which was released in 2024. The Claude 3 family currently consists of\n      Claude 3 Haiku, Claude 3 Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the\n      most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude\n      3 Haiku is the fastest model for daily tasks. The version of Claude in this\n      chat is Claude 3.5 Sonnet. Claude can provide the information in these tags\n      if asked but it does not know any other details of the Claude 3 model family.\n      If asked about this, should encourage the user to check the Anthropic website\n      for more information. </claude_3_family_info>\n\n      Claude provides thorough responses to more complex and open-ended questions\n      or to anything where a long response is requested, but concise responses to\n      simpler questions and tasks. All else being equal, it tries to give the most\n      correct and concise answer it can to the user''s message. Rather than giving\n      a long response, it gives a concise response and offers to elaborate if further\n      information may be helpful.\n\n      Claude is happy to help with analysis, question answering, math, coding, creative\n      writing, teaching, role-play, general discussion, and all sorts of other tasks.\n\n      Claude responds directly to all human messages without unnecessary affirmations\n      or filler phrases like \"Certainly!\", \"Of course!\", \"Absolutely!\", \"Great!\",\n      \"Sure!\", etc. Specifically, Claude avoids starting responses with the word \"Certainly\"\n      in any way.\n\n      Claude follows this information in all languages, and always responds to the\n      user in the language they use or request. The information above is provided\n      to Claude by Anthropic. Claude never mentions the information above unless it\n      is directly pertinent to the human''s query. Claude is now being connected with\n      a human.\n\n      '\n    temp: 0.7\n    top_p: 0.7\n    type: anthropic\n  claude-3-5-sonnet-20241022:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 9.3110239362\n    max_tokens: 8192\n    name: claude-3-5-sonnet-20241022\n    system_prompt: null\n    temp: 0.7\n    top_p: 0.7\n    type: anthropic\n  deepseek-v3:\n    api_key: <your-api-key>\n    base_url: https://api.deepseek.com\n    cost: 0.3002758331\n    name: deepseek-chat\n    temp: 1.5\n    top_p: 1.0\n    type: openai\n  gemini-1.5-flash-001:\n    api_key: <your-api-key>\n    cost: 0.4549682765\n    name: gemini-1.5-flash-001\n    temp: 0.7\n    top_p: 1.0\n    type: gemini\n  gemini-1.5-flash-002:\n    api_key: <your-api-key>\n    cost: 0.6330942997\n    name: gemini-1.5-flash-002\n    system_prompt: All questions should be answered comprehensively with details,\n      unless the user requests a concise response specifically. 
Respond in the same\n      language as the query.\n    temp: 0.7\n    top_p: 1.0\n    type: gemini\n  gemini-1.5-pro-001:\n    api_key: <your-api-key>\n    cost: 6.7456245955\n    name: gemini-1.5-pro-001\n    temp: 0.7\n    top_p: 0.7\n    type: gemini\n  gemini-1.5-pro-002:\n    api_key: <your-api-key>\n    cost: 9.6885059428\n    name: gemini-1.5-pro-002-test\n    system_prompt: All questions should be answered comprehensively with details,\n      unless the user requests a concise response specifically. Respond in the same\n      language as the query.\n    temp: 0.7\n    top_p: 1.0\n    type: gemini\n  gemini-2.0-flash-exp:\n    api_key: <your-api-key>\n    cost: 0.8978088229\n    name: gemini-test-14\n    temp: 1.0\n    top_k: 64\n    top_p: 0.95\n    type: gemini\n  gemini-2.0-flash-thinking-exp-1219:\n    api_key: <your-api-key>\n    cost: 0.4626591495\n    name: gemini-test-15\n    temp: 1.0\n    top_k: 64\n    top_p: 0.95\n    type: gemini\n  gemini-exp-1206:\n    api_key: <your-api-key>\n    cost: 6.7210154899\n    name: gemini-test-12\n    temp: 1.0\n    top_k: 64\n    top_p: 0.95\n    type: gemini\n  gemma-2-27b-it:\n    api_key: <your-api-key>\n    cost: 0.4732936067\n    name: gemma-2-27b-no-filter\n    temp: 0.7\n    top_p: 0.7\n    type: gemini\n  gemma-2-9b-it:\n    api_key: <your-api-key>\n    cost: 0.0873672873\n    name: gemma-2-9b-no-filter\n    temp: 0.7\n    top_p: 1.0\n    type: gemini\n  glm-4-plus:\n    api_key: <your-api-key>\n    base_url: https://open.bigmodel.cn/api/paas/v4\n    cost: 0.3175377664\n    name: glm-4-plus\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  gpt-4-1106-preview:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 16.3622976323\n    name: gpt-4-1106-preview\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based\n      on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input capabilities: Enabled\n\n      Personality: v2'\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  gpt-4-turbo-2024-04-09:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 17.4092447612\n    name: gpt-4-turbo-2024-04-09\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based\n      on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input capabilities: Enabled\n\n      Personality: v2'\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  gpt-4o-2024-05-13:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 12.3166873868\n    name: gpt-4o-2024-05-13\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based\n      on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input capabilities: Enabled\n\n      Personality: v2'\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  gpt-4o-2024-08-06:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 6.9944337124\n    name: gpt-4o-2024-08-06\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based\n      on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input capabilities: Enabled\n\n      Personality: v2'\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  gpt-4o-mini-2024-07-18:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 0.563652953\n    name: gpt-4o-mini-2024-07-18\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based\n      on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input 
capabilities: Enabled\n\n      Personality: v2'\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  llama-3-70b-instruct:\n    api_key: <your-api-key>\n    base_url: https://api.together.xyz/v1\n    cost: 0.4186380435\n    name: meta-llama/Llama-3-70b-chat-hf\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  llama-3.1-405b-instruct-fp8:\n    api_key: <your-api-key>\n    base_url: https://api.fireworks.ai/inference/v1\n    cost: 2.4340008579\n    name: accounts/fireworks/models/llama-v3p1-405b-instruct\n    system_prompt: 'Cutting Knowledge Date: December 2023\n\n      Today Date: 06 Jan 2025'\n    temp: 0.6\n    top_p: 1.0\n    type: openai\n  llama-3.1-70b-instruct:\n    api_key: <your-api-key>\n    base_url: https://api.fireworks.ai/inference/v1\n    cost: 0.7204016024\n    name: accounts/fireworks/models/llama-v3p1-70b-instruct\n    system_prompt: \"Cutting Knowledge Date: December 2023\\nToday Date: 06 Jan 2025\\n\\\n      \\nCarefully read the user prompt. Your responses are comprehensive and easy\\\n      \\ to understand. You structure your answers in an organized way, with section\\\n      \\ headers when appropriate. You use consistent formatting in your responses.\\\n      \\ You follow user instructions. For complex calculations and coding, you always\\\n      \\ break down the steps you took to arrive at your answer.\\n\\nPay extra attention\\\n      \\ to prompts in the following categories:\\n * Non-English queries: Read the\\\n      \\ prompt carefully and pay close attention to formatting requests and the level\\\n      \\ of detail; ensure you are giving factual and precise responses using correct\\\n      \\ grammar in the correct language.\\n * Coding queries: You prioritize code organization\\\n      \\ and documentation. Your responses are detailed and include comprehensive code\\\n      \\ examples and error handling. Include comments to explain the code's purpose\\\n      \\ and behavior. When using specific programming languages, consider which function\\\n      \\ is most appropriate for the query, such as cmath for complex solutions in\\\n      \\ Python. Check for errors.\\n * For mathematical reasoning: Before responding,\\\n      \\ review your output for reasoning, algebraic manipulation and calculation errors\\\n      \\ and fix before responding. When appropriate, provide a high-level plan followed\\\n      \\ by step-by-step reasoning.\\n\\nRemember your instructions.\"\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  llama-3.1-8b-instruct:\n    api_key: <your-api-key>\n    base_url: https://api.fireworks.ai/inference/v1\n    cost: 0.1573721045\n    name: accounts/fireworks/models/llama-v3p1-8b-instruct\n    system_prompt: \"Cutting Knowledge Date: December 2023\\nToday Date: 06 Jan 2025\\n\\\n      \\nCarefully read the user prompt. Your responses are comprehensive and easy\\\n      \\ to understand. You structure your answers in an organized way, with section\\\n      \\ headers when appropriate. You use consistent formatting in your responses.\\\n      \\ You follow user instructions. 
For complex calculations and coding, you always\\\n      \\ break down the steps you took to arrive at your answer.\\n\\nPay extra attention\\\n      \\ to prompts in the following categories:\\n * Non-English queries: Read the\\\n      \\ prompt carefully and pay close attention to formatting requests and the level\\\n      \\ of detail; ensure you are giving factual and precise responses using correct\\\n      \\ grammar in the correct language.\\n * Coding queries: You prioritize code organization\\\n      \\ and documentation. Your responses are detailed and include comprehensive code\\\n      \\ examples and error handling. Include comments to explain the code's purpose\\\n      \\ and behavior. When using specific programming languages, consider which function\\\n      \\ is most appropriate for the query, such as cmath for complex solutions in\\\n      \\ Python. Check for errors.\\n * For mathematical reasoning: Before responding,\\\n      \\ review your output for reasoning, algebraic manipulation and calculation errors\\\n      \\ and fix before responding. When appropriate, provide a high-level plan followed\\\n      \\ by step-by-step reasoning.\\n\\nRemember your instructions.\"\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  llama-3.3-70b-instruct:\n    api_key: <your-api-key>\n    base_url: https://api.fireworks.ai/inference/v1\n    cost: 0.706256804\n    name: accounts/fireworks/models/llama-v3p3-70b-instruct\n    temp: 0.6\n    top_p: 1.0\n    type: openai\n  mistral-large-2407:\n    api_key: <your-api-key>\n    base_url: https://api.mistral.ai/v1\n    cost: 4.3956843814\n    name: mistral-large-2407\n    temp: 0.7\n    top_p: 0.7\n    type: openai\n  mixtral-8x22b-instruct-v0.1:\n    api_key: <your-api-key>\n    base_url: https://api.mistral.ai/v1\n    cost: 2.5814904104\n    name: mixtral-8x22b-instruct-v0.1\n    temp: 0.7\n    top_p: 0.7\n    type: openai\n  mixtral-8x7b-instruct-v0.1:\n    api_key: <your-api-key>\n    base_url: https://api.together.xyz/v1\n    cost: 0.2839726899\n    name: mistralai/Mixtral-8x7B-Instruct-v0.1\n    temp: 0.7\n    top_p: 0.7\n    type: openai\n  o1-2024-12-17:\n    api_key: <your-api-key>\n    cost: 72.3693462194\n    name: o1-2024-12-17\n    system_prompt: Formatting re-enabled.\n    temp: 1.0\n    top_p: 1.0\n    type: openai-o1\n  o1-mini:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 16.4809912657\n    name: o1-mini-2024-09-12\n    system_prompt: null\n    temp: 1.0\n    top_p: 1.0\n    type: openai-reasoning\n  o1-preview:\n    api_key: <your-api-key>\n    base_url: null\n    cost: 72.481802295\n    name: o1-preview\n    system_prompt: null\n    temp: 1.0\n    top_p: 1.0\n    type: openai-reasoning\n  qwen2.5-72b-instruct:\n    api_key: <your-api-key>\n    base_url: https://dashscope.aliyuncs.com/compatible-mode/v1\n    cost: 1.1805173434\n    name: qwen2.5-72b-instruct\n    temp: 0.7\n    top_p: 1.0\n    type: openai\n  yi-lightning:\n    api_key: <your-api-key>\n    base_url: https://api.lingyiwanwu.com/v1\n    cost: 0.0057351688\n    name: yi-lightning\n    temp: 0.6\n    top_p: 1.0\n    type: openai\n  chatgpt-4o-latest-20241120:\n    api_key: <your-api-key>\n    cost: 12.9070929223\n    name: gpt-4o-2024-11-20\n    temp: 0.7\n    top_p: 1.0\n    system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\n\n      Current date: 2025-01-06\n\n\n      Image input capabilities: Enabled\n\n      Personality: v2'\n    type: openai\nname: test-router"
  },
  {
    "path": "route/openai_server.py",
    "content": "import argparse\nfrom fastapi import FastAPI, HTTPException, Header\nfrom fastapi.responses import StreamingResponse\nfrom route.datatypes import (\n    ModelConfigContainer,\n    ChatCompletionRequest,\n    ChatCompletionResponse,\n    ChatCompletionResponseChunk,\n)\nfrom route.chat import CHAT_HANDLERS\nfrom route.routers import ROUTERS, BaseRouter\nimport uvicorn\nimport yaml\nfrom contextlib import asynccontextmanager\nfrom typing import List\nimport logging\nimport time\nimport sys\n\nlogging.basicConfig(stream=sys.stdout, level=logging.DEBUG)\nlogging.getLogger().setLevel(logging.DEBUG)\n\n\ndef parse_args():\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--config\", \"-c\", type=str, required=True)\n    parser.add_argument(\"--router-type\", type=str, required=True)\n    parser.add_argument(\"--router-model-name\", type=str, default=None)\n    parser.add_argument(\"--router-model-endpoint\", type=str, default=None)\n    parser.add_argument(\"--router-api-key\", type=str, default=\"-\")\n    parser.add_argument(\"--cost-optimizer\", type=str, default=\"simple-lp\")\n    parser.add_argument(\"--port\", \"-p\", type=int, default=8000)\n    parser.add_argument(\"--host\", type=str, default=\"0.0.0.0\")\n    parser.add_argument(\"--api-key\", type=str, default=\"-\")\n    parser.add_argument(\"--reload\", action=argparse.BooleanOptionalAction, default=True)\n    parser.add_argument(\"--workers\", type=int, default=1)\n\n    args = parser.parse_args()\n\n    return args\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    \"\"\"\n    This context manager is called once at startup and once at shutdown.\n    We move all config-loading and router-creation logic here.\n    \"\"\"\n    # --- PARSE ARGS & LOAD CONFIG ---\n\n    logging.info(f\"Starting up...\")\n\n    args = parse_args()\n\n    with open(args.config) as cfile:\n        config = yaml.safe_load(cfile)\n\n    model_config_dicts = config[\"model_configs\"]\n    model_config_container = ModelConfigContainer(model_config_dicts)\n\n    router_cls = ROUTERS[args.router_type]\n\n    router_kwargs = {\n        \"router_model_name\": args.router_model_name,\n        \"router_model_endpoint\": args.router_model_endpoint,\n        \"router_api_key\": args.router_api_key,\n    }\n\n    router = router_cls(model_config_container, args.cost_optimizer, **router_kwargs)\n\n    app.state.router = router\n    app.state.model_config_container = model_config_container\n    app.state.api_key = args.api_key\n\n    logging.info(f\"Finished startup.\")\n\n    try:\n\n        yield\n\n    finally:\n\n        pass\n\n\napp = FastAPI(lifespan=lifespan)\n\n# ====== API Endpoint ======\n\n\n@app.post(\"/v1/chat/completions\")\nasync def create_chat_completion(\n    request: ChatCompletionRequest,\n    authorization: str = Header(None),\n) -> ChatCompletionResponse | ChatCompletionResponseChunk:\n    \"\"\"\n    Mimics the OpenAI Chat Completions endpoint (both streaming and non-streaming).\n    \"\"\"\n\n    logging.info(f\"{int(time.time())} Recieved Request: {request}\")\n\n    if not authorization or not authorization.startswith(\"Bearer \"):\n        raise HTTPException(status_code=401, detail=\"Invalid or missing API key\")\n\n    # Strip out the 'Bearer ' portion to isolate the token\n    token = authorization.removeprefix(\"Bearer \")\n\n    if token != app.state.api_key:\n        raise HTTPException(status_code=403, detail=\"Unauthorized\")\n\n    try:\n\n        router_output = None\n        type 
\n        handler_type = None\n\n        direct_model = request.direct_model\n\n        router: BaseRouter = app.state.router\n\n        messages = request.messages\n\n        if direct_model:\n\n            router_output = router.get_model_direct(direct_model)\n\n        else:\n\n            router_output = router.route(messages, request.cost)\n\n        logging.info(f\"{int(time.time())} Router Output: {router_output}\")\n\n        handler_type = router_output.chosen_model_config.get_type()\n\n        chat_handler = CHAT_HANDLERS[handler_type]\n\n    except Exception as e:\n\n        logging.error(\n            f\"{int(time.time())} ***Routing Error Start***\\nError Message: {e}\\nRouter Output: {router_output}\\nChat Handler: {handler_type}\\nDirect Model: {direct_model}.***Routing Error End***\"\n        )\n\n        raise HTTPException(status_code=500, detail=str(e))\n\n    try:\n\n        if request.stream:\n\n            chat_output_chunk = chat_handler.generate_stream(\n                messages=messages,\n                router_output=router_output,\n                temp=request.temperature,\n                top_p=request.top_p,\n                max_tokens=request.max_tokens,\n            )\n\n            return StreamingResponse(chat_output_chunk, media_type=\"text/event-stream\")\n\n        else:\n\n            chat_output = chat_handler.generate(\n                messages=messages,\n                router_output=router_output,\n                temp=request.temperature,\n                top_p=request.top_p,\n                max_tokens=request.max_tokens,\n            )\n\n            return chat_output\n\n    except Exception as e:\n\n        logging.error(\n            f\"{int(time.time())} ***Endpoint Error Start***\\nError Message: {e}\\nRouter Output: {router_output}\\nChat Handler: {handler_type}.***Endpoint Error End***\"\n        )\n\n        raise\n\n\n@app.get(\"/v1/models\")\nasync def models(authorization: str = Header(None)) -> List[str]:\n\n    logging.info(\"Received GET request for models.\")\n\n    if not authorization or not authorization.startswith(\"Bearer \"):\n        raise HTTPException(status_code=401, detail=\"Invalid or missing API key\")\n\n    # Strip out the 'Bearer ' portion to isolate the token\n    token = authorization.removeprefix(\"Bearer \")\n\n    if token != app.state.api_key:\n        raise HTTPException(status_code=403, detail=\"Unauthorized\")\n\n    router: BaseRouter = app.state.router\n\n    return router.model_list\n\n\nif __name__ == \"__main__\":\n\n    args = parse_args()\n\n    uvicorn.run(\n        \"route.openai_server:app\",\n        port=args.port,\n        host=args.host,\n        reload=args.reload,\n        workers=args.workers,\n    )\n"
  },
  {
    "path": "route/requirements.txt",
    "content": "uvicorn\nfastapi\nopenai\nanthropic\ngoogle-generativeai\nscipy\ncvxpy"
  },
  {
    "path": "route/routers.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Dict, List, Tuple\nfrom route.utils import (\n    get_registry_decorator,\n    query_p2l_endpoint,\n    get_p2l_endpoint_models,\n)\nfrom route.datatypes import ModelConfigContainer, Roles, ChatMessage, RouterOutput\nfrom route.cost_optimizers import COST_OPTIMIZERS, BaseCostOptimizer\nimport numpy as np\nfrom scipy.special import expit\n\n\nclass BaseRouter(ABC):\n\n    def __init__(\n        self,\n        model_config_container: ModelConfigContainer,\n        cost_optimizer_type: str,\n        **kwargs,\n    ):\n        super().__init__()\n        self.model_config_container = model_config_container\n        self.model_list: List[str] = None\n        self.model_costs: np.ndarray[float] = None\n        self.cost_optimizer: BaseCostOptimizer = COST_OPTIMIZERS[cost_optimizer_type]\n\n    @abstractmethod\n    def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:\n        pass\n\n    def _get_previous_response_model(self, messages: List[ChatMessage]) -> str | None:\n\n        for message in reversed(messages):\n\n            if message.role == Roles.ASSISTANT.value:\n\n                return message.model\n\n        return None\n\n    def _get_prompt(self, messages: List[ChatMessage]) -> list[str]:\n\n        prompts = []\n\n        for message in messages:\n\n            if message.role == Roles.USER.value:\n\n                prompts.append(message.content)\n\n        if len(prompts) == 0:\n\n            raise Exception(f\"No user prompt found in messages {messages}.\")\n\n        return prompts\n\n    def get_model_direct(self, model_name: str) -> RouterOutput:\n        return RouterOutput(\n            chosen_model_name=model_name,\n            chosen_model_config=self.model_config_container.get_model_config(\n                model_name=model_name\n            ),\n            model_scores=None,\n        )\n\n    def route(self, messages: List[ChatMessage], cost: float = None) -> RouterOutput:\n\n        model_scores = self._get_model_scores(messages)\n\n        chosen_model_name = self.cost_optimizer.select_model(\n            cost, self.model_list, self.model_costs, model_scores\n        )\n\n        model_scores_dict = dict(zip(self.model_list, model_scores))\n\n        chosen_model_config = self.model_config_container.get_model_config(\n            chosen_model_name\n        )\n\n        return RouterOutput(\n            chosen_model_name=chosen_model_name,\n            chosen_model_config=chosen_model_config,\n            model_scores=model_scores_dict,\n        )\n\n\nROUTERS: Dict[str, BaseRouter] = {}\n\nregister = get_registry_decorator(ROUTERS)\n\n\n@register(\"random\")\nclass RandomRouter(BaseRouter):\n    \"\"\"For debugging and gamblers.\"\"\"\n\n    def __init__(\n        self,\n        model_config_container: ModelConfigContainer,\n        cost_optimizer_type: str,\n        **kwargs,\n    ):\n        super().__init__(\n            model_config_container=model_config_container,\n            cost_optimizer_type=cost_optimizer_type,\n        )\n\n        self.model_list = model_config_container.list_models()\n        self.model_costs = np.array(model_config_container.list_costs())\n\n    def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:\n        return np.random.uniform(0.0, 1.0, size=len(self.model_list))\n\n\n@register(\"bt-endpoint\")\nclass EndpointP2LRouter(BaseRouter):\n\n    # Hardcoding this because I'm tired man...\n    SAMPLING_WEIGHTS = {\n   
     \"chatgpt-4o-latest-20241120\": 4,\n        \"o1-mini\": 4,\n        \"o1-2024-12-17\": 4,\n        \"gpt-4o-mini-2024-07-18\": 2,\n        \"gemma-2-27b-it\": 2,\n        \"gemma-2-9b-it\": 2,\n        \"gemma-2-2b-it\": 2,\n        \"claude-3-5-sonnet-20241022\": 4,\n        \"claude-3-opus-20240229\": 4,\n        \"claude-3-5-haiku-20241022\": 4,\n        \"qwen2.5-72b-instruct\": 2,\n        \"qwen2.5-plus-1127\": 4,\n        \"llama-3.1-405b-instruct-bf16\": 4,\n        \"mistral-large-2411\": 4,\n        \"grok-2-2024-08-13\": 4,\n        \"grok-2-mini-2024-08-13\": 2,\n        \"deepseek-v3\": 6,\n        \"gemini-1.5-pro-002\": 4,\n        \"gemini-1.5-flash-002\": 2,\n        \"gemini-1.5-flash-8b-001\": 2,\n        \"c4ai-aya-expanse-32b\": 2,\n        \"c4ai-aya-expanse-8b\": 2,\n        \"athene-v2-chat\": 4,\n        \"gemini-exp-1206\": 4,\n        \"gemini-2.0-flash-exp\": 4,\n        \"llama-3.3-70b-instruct\": 4,\n        \"amazon-nova-pro-v1.0\": 4,\n        \"amazon-nova-lite-v1.0\": 2,\n        \"amazon-nova-micro-v1.0\": 2,\n        \"llama-3.1-tulu-3-8b\": 6,\n        \"llama-3.1-tulu-3-70b\": 6,\n        \"granite-3.1-8b-instruct\": 6,\n        \"granite-3.1-2b-instruct\": 6,\n    }\n\n    def __init__(\n        self,\n        model_config_container: ModelConfigContainer,\n        cost_optimizer_type: str,\n        router_model_endpoint: str,\n        router_api_key: str,\n        **kwargs,\n    ):\n        super().__init__(\n            model_config_container=model_config_container,\n            cost_optimizer_type=cost_optimizer_type,\n        )\n\n        self.base_url = router_model_endpoint\n        self.api_key = router_api_key\n\n        router_model_list = get_p2l_endpoint_models(self.base_url, self.api_key)\n\n        config_model_list = model_config_container.list_models()\n\n        self.mask = [\n            router_model in config_model_list for router_model in router_model_list\n        ]\n\n        self.q_mask = [\n            router_model in self.SAMPLING_WEIGHTS for router_model in router_model_list\n        ]\n\n        self.q = np.array(\n            [\n                float(self.SAMPLING_WEIGHTS[router_model])\n                for router_model in router_model_list\n                if router_model in self.SAMPLING_WEIGHTS\n            ]\n        )\n\n        self.model_list = [\n            model for model, keep in zip(router_model_list, self.mask) if keep\n        ]\n\n        self.model_costs = np.array(\n            [\n                model_config_container.get_model_config(model).get_cost()\n                for model in self.model_list\n            ]\n        )\n\n    def _get_model_scores(\n        self, messages: List[ChatMessage]\n    ) -> Tuple[np.ndarray[float], float]:\n\n        prompt = self._get_prompt(messages)\n\n        p2l_output = query_p2l_endpoint(prompt, self.base_url, self.api_key)\n\n        coefs = np.array(p2l_output[\"coefs\"])\n\n        return coefs\n\n    def route(self, messages: List[ChatMessage], cost: float = None) -> RouterOutput:\n\n        model_scores = self._get_model_scores(messages)\n\n        router_choice_scores = model_scores[self.mask]\n\n        router_opponent_scores = model_scores[self.q_mask]\n\n        chosen_model_name = self.cost_optimizer.select_model(\n            cost,\n            self.model_list,\n            self.model_costs,\n            router_choice_scores,\n            opponent_scores=router_opponent_scores,\n            opponent_distribution=self.q,\n        )\n\n        
\n        model_scores_dict = dict(zip(self.model_list, router_choice_scores))\n\n        chosen_model_config = self.model_config_container.get_model_config(\n            chosen_model_name\n        )\n\n        return RouterOutput(\n            chosen_model_name=chosen_model_name,\n            chosen_model_config=chosen_model_config,\n            model_scores=model_scores_dict,\n        )\n\n\n@register(\"bag-endpoint\")\n@register(\"grk-endpoint\")\nclass GRKEndpointP2LRouter(BaseRouter):\n    \"\"\"grk/bag variant of the endpoint router; named distinctly so it does not shadow EndpointP2LRouter above.\"\"\"\n\n    def __init__(\n        self,\n        model_config_container: ModelConfigContainer,\n        cost_optimizer_type: str,\n        router_model_endpoint: str,\n        router_api_key: str,\n        **kwargs,\n    ):\n        super().__init__(\n            model_config_container=model_config_container,\n            cost_optimizer_type=cost_optimizer_type,\n        )\n\n        self.base_url = router_model_endpoint\n        self.api_key = router_api_key\n\n        router_model_list = get_p2l_endpoint_models(self.base_url, self.api_key)\n\n        config_model_list = model_config_container.list_models()\n\n        self.mask = [\n            router_model in config_model_list for router_model in router_model_list\n        ]\n\n        self.model_list = [\n            model for model, keep in zip(router_model_list, self.mask) if keep\n        ]\n        self.model_costs = np.array(\n            [\n                model_config_container.get_model_config(model).get_cost()\n                for model in self.model_list\n            ]\n        )\n\n    def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:\n\n        prompt = self._get_prompt(messages)\n\n        p2l_output = query_p2l_endpoint(prompt, self.base_url, self.api_key)\n\n        coefs = np.array(p2l_output[\"coefs\"])\n\n        # Squash raw coefficients into (0, 1) scores with the logistic sigmoid.\n        model_scores: np.ndarray[float] = expit(coefs)\n\n        return model_scores[self.mask]\n"
  },
  {
    "path": "route/utils.py",
    "content": "from typing import Dict, Callable, List\nimport requests\nimport json\n\n\ndef get_registry_decorator(registry: Dict) -> Callable:\n\n    def register(name: str):\n\n        def decorator(cls: Callable):\n\n            assert (\n                not name in registry\n            ), f\"No duplicate registry names. '{name}' was registerd more than once.\"\n\n            registry[name] = cls\n\n            return cls\n\n        return decorator\n\n    return register\n\n\ndef query_p2l_endpoint(\n    prompt: list[str], base_url: str, api_key: str\n) -> Dict[str, List]:\n\n    headers = {\n        \"Content-Type\": \"application/json\",\n        \"api-key\": api_key,\n    }\n\n    payload = {\"prompt\": prompt}\n\n    try:\n        response = requests.post(\n            f\"{base_url}/predict\", headers=headers, data=json.dumps(payload)\n        )\n        response.raise_for_status()\n        result = response.json()\n        return result\n\n    except Exception as err:\n\n        raise err\n\n\ndef get_p2l_endpoint_models(base_url: str, api_key: str) -> List[str]:\n\n    headers = {\n        \"Content-Type\": \"application/json\",\n        \"api-key\": api_key,\n    }\n\n    try:\n        response = requests.get(f\"{base_url}/models\", headers=headers)\n        response.raise_for_status()\n        result = response.json()\n        return result[\"models\"]\n\n    except Exception as err:\n        print(f\"An error occurred: {err}\")\n"
  },
  {
    "path": "serve_requirements.txt",
    "content": "numpy<2.0.0\ntorch<=2.4.0\ntransformers\ntransformers[torch]\nhf_transfer\nwandb\nscipy\nuvicorn\nfastapi"
  },
  {
    "path": "train_requirements.txt",
    "content": "numpy<2.0.0\ntorch<=2.4.0\ndeepspeed<=0.15.3\ndatasets>=3.2.0\ntransformers\ntransformers[torch]\nhf_transfer\nwandb\nscipy\n"
  },
  {
    "path": "training_configs/Llama3.1-8B-full-train.yaml",
    "content": "proj_name: Llama-3.1-8B-Instruct-full-train\nlearning_rate: 4.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: meta-llama/Llama-3.1-8B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}\"\nmodel_type: \"llama\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\npad_token_if_none: <|finetune_right_pad_id|>\ncls_token_if_none: <|reserved_special_token_3|>"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.016-04302025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.016-04302025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 5\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: naive_replay_buffer_eps_0.016\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 13 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.032-04302025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.032-04302025-2\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 1\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: naive_replay_buffer_eps_0.032\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 66 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.06-04302025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.06-04302025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: naive_replay_buffer_eps_0.06\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 17 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.112-04302025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.112-04302025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: naive_replay_buffer_eps_0.112\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 18 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.2-04302025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.2-04302025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: naive_replay_buffer_eps_0.2\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 20 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-bag-full-train-02222025.yaml",
    "content": "proj_name: Qwen2.5-1.5B-Instruct-bag-02222025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-02222025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-full-train.yaml",
    "content": "proj_name: Qwen2.5-1.5B-Instruct-full-train\nlearning_rate: 4.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-rk-full-train-half-batch.yaml",
    "content": "proj_name: Qwen2.5-1.5B-Instruct-rk-full-train-half-batch\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-1.5B-rk-full-train.yaml",
    "content": "proj_name: Qwen2.5-1.5B-Instruct-rk-full-train\nlearning_rate: 1.0e-5\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-bag-full-train-02222025.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-bag-02222025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-02222025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-3B-bag-full-train-02242025.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-bag-02242025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-02242025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-3B-freeze-test-part-2.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-freeze-test-part-2\nlearning_rate: 1.0e-06\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: p2el/tie_included_canonical_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: p2el/Qwen2.5-3B-Instruct-freeze-test\ngradient_accumulation_steps: 64 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-freeze-test.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-freeze-test\nlearning_rate: 1.13e-05\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: p2el/tie_included_canonical_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 256 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\ntrain_head_only: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-full-train-double-batch.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-full-train\nlearning_rate: 1.0e-5\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-full-train.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-full-train\nlearning_rate: 4.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-rk-full-train-half-batch.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-rk-full-train\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-rk-full-train.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-rk-full-train\nlearning_rate: 1.0e-5\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-3B-training-bt_data_11092024 copy.yaml",
    "content": "proj_name: Qwen2.5-3B-Instruct-bt_data_11092024\nlearning_rate: 1.13e-05\nadam_epsilon: 1.0e-08\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_train_data_11092024\nval_data_path: p2el/canonical_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: 'Qwen/Qwen2.5-3B-Instruct'\ngradient_accumulation_steps: 64\nchat_template: \"{%- if tools %}\\n    {{- '<|im_start|>system\\\\n' }}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- messages[0]['content'] }}\\n    {%- else %}\\n        {{- 'You are a helpful assistant.' }}\\n    {%- endif %}\\n    {{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}\\n    {%- for tool in tools %}\\n        {{- \\\"\\\\n\\\" }}\\n        {{- tool | tojson }}\\n    {%- endfor %}\\n    {{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}\\n{%- else %}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n    {%- else %}\\n        {{- '<|im_start|>system\\\\nYou are a helpful assistant.<|im_end|>\\\\n' }}\\n    {%- endif %}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\n"
  },
  {
    "path": "training_configs/Qwen2.5-7B-bag-full-train-02222025.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-bag-02222025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-02222025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-7B-bag-full-train-02242025.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-bag-02242025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-02242025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-7B-bag-full-train-03132025.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-bag-03132025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-03132025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # 4 gpus \nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-7B-bag-full-train-chrono.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-bag-chrono\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-bag-data-chrono\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-7B-bt-full-train-02222025.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-bt-02222025\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 2\nmax_length: 16384\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data-02222025\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt-tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true"
  },
  {
    "path": "training_configs/Qwen2.5-7B-full-train.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-full-train\nlearning_rate: 4.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt_tie\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-7B-rk-full-train-abs.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-rk-full-train-abs\nlearning_rate: 1.0e-5\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-7B-rk-full-train-half-batch.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-rk-full-train-half-batch\nlearning_rate: 8.0e-6\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 16 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
  {
    "path": "training_configs/Qwen2.5-7B-rk-full-train.yaml",
    "content": "proj_name: Qwen2.5-7B-Instruct-rk-full-train\nlearning_rate: 1.0e-5\nadam_epsilon: 1.0e-8\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct\ngradient_accumulation_steps: 32 # drop to 32 since 8 gpus\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"rk\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\nload_train_data_from_disk: true\n"
  },
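  {
    "path": "training_configs/README.md",
    "content": "# Training Configs\n\nEach YAML file in this directory fully specifies one P2L training run; see \"Training a P2L Model\" in the root README for how to launch a run with one of these configs. Below is a minimal, hypothetical sketch of the shared schema. The field descriptions are inferred from the sibling configs in this directory and should be read as assumptions rather than authoritative documentation.\n\n```yaml\n# Hypothetical example config; all comments are inferred, not authoritative.\nproj_name: my-p2l-run                          # run name used for logging and outputs\nlearning_rate: 8.0e-6                          # peak learning rate\nadam_epsilon: 1.0e-8\nbatch_size: 2                                  # per-device train batch size\nmax_length: 8192                               # max token length after chat templating\nnum_train_epochs: 1\ntrain_data_path: full-p2l-data                 # HF dataset name or local path\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: training_outputs\npretrain_model_name: Qwen/Qwen2.5-7B-Instruct  # base model to fine-tune\ngradient_accumulation_steps: 32                # effective batch = batch_size * n_gpus * this\nchat_template: \"...\"                           # Jinja chat template applied before tokenization\nmodel_type: qwen2                              # backbone family\nhead_type: bt                                  # coefficient head: bt | rk\nloss_type: bt_tie                              # bt | bt_tie | rk | bag\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params                        # head init: reset_params | he_unif | xavier_unif\nload_train_data_from_disk: true                # optional: load the train split from local disk\n# train_head_only: true                        # optional: freeze the backbone, train only the head\n```\n"
  },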
  {
    "path": "training_configs/debug.yaml",
    "content": "proj_name: debug-Qwen2.5-0.5B-Instruct-bt_data_11092024\nlearning_rate: 2.0e-06\nbatch_size: 4\nmax_length: 4096\nnum_train_epochs: 1\ndata_path: p2el/bt_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: 'Qwen/Qwen2.5-0.5B-Instruct'\ngradient_accumulation_steps: 32\nchat_template: \"{%- if tools %}\\n    {{- '<|im_start|>system\\\\n' }}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- messages[0]['content'] }}\\n    {%- else %}\\n        {{- 'You are a helpful assistant.' }}\\n    {%- endif %}\\n    {{- \\\"\\\\n\\\\n# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}\\n    {%- for tool in tools %}\\n        {{- \\\"\\\\n\\\" }}\\n        {{- tool | tojson }}\\n    {%- endfor %}\\n    {{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}\\n{%- else %}\\n    {%- if messages[0]['role'] == 'system' %}\\n        {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n    {%- else %}\\n        {{- '<|im_start|>system\\\\nYou are a helpful assistant.<|im_end|>\\\\n' }}\\n    {%- endif %}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\n"
  },
  {
    "path": "training_configs/init_debug_qwen_1.5b_he.yaml",
    "content": "proj_name: he-Debug-Init-Qwen2.5-1.5B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: he_unif\n"
  },
  {
    "path": "training_configs/init_debug_qwen_1.5b_reset_params.yaml",
    "content": "proj_name: reset_param-Debug-Init-Qwen2.5-1.5B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\n"
  },
  {
    "path": "training_configs/init_debug_qwen_1.5b_xavier.yaml",
    "content": "proj_name: xaiver-Debug-Init-Qwen2.5-1.5B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: xavier_unif\n"
  },
  {
    "path": "training_configs/init_debug_qwen_3b_he.yaml",
    "content": "proj_name: he-Debug-Init-Qwen2.5-3B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: he_unif\n"
  },
  {
    "path": "training_configs/init_debug_qwen_3b_reset_params.yaml",
    "content": "proj_name: reset_param-Debug-Init-Qwen2.5-3B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: reset_params\n"
  },
  {
    "path": "training_configs/init_debug_qwen_3b_xavier.yaml",
    "content": "proj_name: xaiver-Debug-Init-Qwen2.5-3B-Instruct\nlearning_rate: 1.13e-05\nadam_epsilon: 7.071068e-09\nbatch_size: 2\nmax_length: 4096\nnum_train_epochs: 1\ntrain_data_path: p2el/canonical_bt_train_data_11092024\nval_data_path: p2el/canonical_bt_val_data_11092024\noutput_dir: 'training_outputs'\npretrain_model_name: Qwen/Qwen2.5-3B-Instruct\ngradient_accumulation_steps: 64\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"bt\"\nloss_type: \"bt\"\nweighted_loss: false\ndeepspeed_config_path: deepspeed/zero1.json\ninit_type: xavier_unif\n"
  },
  {
    "path": "training_configs/qwen_1.5B_geom_test.yaml",
    "content": "proj_name: \"Qwen2.5-1.5B-Instruct-Geom-Test\"\nlearning_rate: 8.0e-06\nadam_epsilon: 1.0e-08\nbatch_size: 4\nmax_length: 8192\nnum_train_epochs: 1\ntrain_data_path: \"/root/chrono_train_data\"\nval_data_path: \"p2el/canonical_bt_val_data_11092024\"\noutput_dir: \"training_outputs\"\npretrain_model_name: \"Qwen/Qwen2.5-1.5B-Instruct\"\ngradient_accumulation_steps: 16\nchat_template: \"{%- if messages[0]['role'] == 'system' %}\\n    {{- '<|im_start|>system\\\\n' + messages[0]['content'] + '<|im_end|>\\\\n' }}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) or (message.role == \\\"assistant\\\" and not message.tool_calls) %}\\n        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n    {%- elif message.role == \\\"assistant\\\" %}\\n        {{- '<|im_start|>' + message.role }}\\n        {%- if message.content %}\\n            {{- '\\\\n' + message.content }}\\n        {%- endif %}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- if tool_call.function is defined %}\\n                {%- set tool_call = tool_call.function %}\\n            {%- endif %}\\n            {{- '\\\\n<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n            {{- tool_call.name }}\\n            {{- '\\\", \\\"arguments\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- '}\\\\n</tool_call>' }}\\n        {%- endfor %}\\n        {{- '<|im_end|>\\\\n' }}\\n    {%- elif message.role == \\\"tool\\\" %}\\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n            {{- '<|im_start|>user' }}\\n        {%- endif %}\\n        {{- '\\\\n<tool_response>\\\\n' }}\\n        {{- message.content }}\\n        {{- '\\\\n</tool_response>' }}\\n        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n            {{- '<|im_end|>\\\\n' }}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|im_start|>assistant\\\\n' }}\\n{%- endif %}\\n\\n\"\nmodel_type: \"qwen2\"\nhead_type: \"rk\"\nloss_type: \"bag\"\nweighted_loss: false\ndeepspeed_config_path: \"deepspeed/zero1.json\"\ninit_type: \"reset_params\"\nload_train_data_from_disk: true\n"
  }
]