Repository: lmarena/p2l
Branch: main
Commit: a905fa5ea94a
Files: 59
Total size: 268.0 KB
Directory structure:
gitextract_yhm008o_/
├── .gitignore
├── README.md
├── deepspeed/
│ └── zero1.json
├── fast_lambda_setup.sh
├── fast_runpod_setup.sh
├── p2l/
│ ├── auto_eval_utils.py
│ ├── auto_evals.py
│ ├── dataset.py
│ ├── endpoint.py
│ ├── eval.py
│ ├── model.py
│ └── train.py
├── probe_barrier.py
├── route/
│ ├── chat.py
│ ├── cost_optimizers.py
│ ├── datatypes.py
│ ├── example_config.yaml
│ ├── openai_server.py
│ ├── requirements.txt
│ ├── routers.py
│ └── utils.py
├── serve_requirements.txt
├── train_requirements.txt
└── training_configs/
├── Llama3.1-8B-full-train.yaml
├── Qwen2.5-1.5B-bag-chrono-eps-0.016-04302025.yaml
├── Qwen2.5-1.5B-bag-chrono-eps-0.032-04302025.yaml
├── Qwen2.5-1.5B-bag-chrono-eps-0.06-04302025.yaml
├── Qwen2.5-1.5B-bag-chrono-eps-0.112-04302025.yaml
├── Qwen2.5-1.5B-bag-chrono-eps-0.2-04302025.yaml
├── Qwen2.5-1.5B-bag-full-train-02222025.yaml
├── Qwen2.5-1.5B-full-train.yaml
├── Qwen2.5-1.5B-rk-full-train-half-batch.yaml
├── Qwen2.5-1.5B-rk-full-train.yaml
├── Qwen2.5-3B-bag-full-train-02222025.yaml
├── Qwen2.5-3B-bag-full-train-02242025.yaml
├── Qwen2.5-3B-freeze-test-part-2.yaml
├── Qwen2.5-3B-freeze-test.yaml
├── Qwen2.5-3B-full-train-double-batch.yaml
├── Qwen2.5-3B-full-train.yaml
├── Qwen2.5-3B-rk-full-train-half-batch.yaml
├── Qwen2.5-3B-rk-full-train.yaml
├── Qwen2.5-3B-training-bt_data_11092024 copy.yaml
├── Qwen2.5-7B-bag-full-train-02222025.yaml
├── Qwen2.5-7B-bag-full-train-02242025.yaml
├── Qwen2.5-7B-bag-full-train-03132025.yaml
├── Qwen2.5-7B-bag-full-train-chrono.yaml
├── Qwen2.5-7B-bt-full-train-02222025.yaml
├── Qwen2.5-7B-full-train.yaml
├── Qwen2.5-7B-rk-full-train-abs.yaml
├── Qwen2.5-7B-rk-full-train-half-batch.yaml
├── Qwen2.5-7B-rk-full-train.yaml
├── debug.yaml
├── init_debug_qwen_1.5b_he.yaml
├── init_debug_qwen_1.5b_reset_params.yaml
├── init_debug_qwen_1.5b_xavier.yaml
├── init_debug_qwen_3b_he.yaml
├── init_debug_qwen_3b_reset_params.yaml
├── init_debug_qwen_3b_xavier.yaml
└── qwen_1.5B_geom_test.yaml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
================================================
FILE: README.md
================================================
# Prompt-to-Leaderboard (P2L)
This is the codebase for the paper [Prompt-to-Leaderboard](https://arxiv.org/pdf/2502.14855).
Model weights can be found in our [LMArena HF Collection](https://huggingface.co/collections/lmarena-ai/prompt-to-leaderboard-67bcf7ddf6022ef3cfd260cc).
Try it on Chatbot Arena under the [Prompt-to-Leaderboard](https://lmarena.ai/?p2l) tab!
## Abstract
Large language model (LLM) evaluations typically rely on aggregated metrics like accuracy or human preference, averaging across users and prompts. This averaging obscures user- and prompt-specific variations in model performance.
To address this, we propose Prompt-to-Leaderboard (P2L), a method that produces leaderboards specific to a prompt or set of prompts.
The core idea is to train an LLM taking natural language prompts as input to output a vector of Bradley-Terry coefficients which are then used to predict the human preference vote.
The resulting prompt-dependent leaderboards allow for unsupervised task-specific evaluation, optimal routing of queries to models, personalization, and automated evaluation of model strengths and weaknesses.
Data from Chatbot Arena suggest that P2L better captures the nuanced landscape of language model performance than the averaged leaderboard.
Furthermore, our findings suggest that P2L's ability to produce prompt-specific evaluations follows a power law scaling similar to that observed in LLMs themselves. In January 2025, the router we trained based on this methodology achieved the #1 spot in the Chatbot Arena leaderboard.
## Table of Contents
- [Prompt-to-Leaderboard (P2L)](#prompt-to-leaderboard-p2l)
- [Abstract](#abstract)
- [Table of Contents](#table-of-contents)
- [Environment Setup](#environment-setup)
- [Installing `uv`](#installing-uv)
- [Serving P2L Setup](#serving-p2l-setup)
- [Serving a Router Setup](#serving-a-router-setup)
- [Training Setup](#training-setup)
- [Serving P2L](#serving-p2l)
- [Serving an OpenAI Compatible Router](#serving-an-openai-compatible-router)
- [Example: serving a Bradley-Terry based cost-optimal router](#example-serving-a-bradley-terry-based-cost-optimal-router)
- [Example: serving a Grounded RK based simple cost router](#example-serving-a-grounded-rk-based-simple-cost-router)
- [Calling the OpenAI Compatible Router](#calling-the-openai-compatible-router)
- [Training a P2L Model](#training-a-p2l-model)
- [Inferencing a P2L Model](#inferencing-a-p2l-model)
- [AutoEval Suite](#autoeval-suite)
- [Params](#params)
- [Citation](#citation)
## Environment Setup
Setup instructions will be shown using `uv`; however, any package management system will work. All environments are native to Python 3.10; other versions are untested but may also work.
### Installing `uv`
If you like the sound of ~50x faster environment setup times, run the following to install `uv`.
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
To create a Python virtual environment run:
```bash
uv venv .env --python 3.10
```
To activate said environment, run:
```bash
source .env/bin/activate
```
### Serving P2L Setup
To serve a P2L model first run:
```bash
uv pip install -r serve_requirements.txt
```
### Serving a Router Setup
To serve an OpenAI compatible router, first run:
```bash
uv pip install -r route/requirements.txt
```
### Training Setup
To train a P2L model first run:
```bash
uv pip install -r train_requirements.txt
```
## Serving P2L
Before getting started, make sure you have followed the steps in [Serving P2L Setup](#serving-p2l-setup).
`python -m p2l.endpoint` accepts the following arguments:
| Option | Short Flag | Description |
|--------|-----------|-------------|
| `--help` | `-h` | Show this help message and exit. |
| `--model-path MODEL_PATH` | `-m MODEL_PATH` | Path to the model repository. |
| `--model-type MODEL_TYPE` | `-mt MODEL_TYPE` | Type of the model. |
| `--head-type HEAD_TYPE` | `-ht HEAD_TYPE` | Type of model head. |
| `--loss-type LOSS_TYPE` | `-lt LOSS_TYPE` | Type of the loss function. |
| `--api-key API_KEY` | `-a API_KEY` | API key for authorization. |
| `--host HOST` | `-H HOST` | Host to run the server on. |
| `--port PORT` | `-p PORT` | Port to run the server on. |
| `--reload, --no-reload` | - | Whether to reload the endpoint on detected code changes (requires workers to be set to 1). |
| `--workers WORKERS` | - | Number of endpoint workers (each will hold a model instance). |
| `--cuda, --no-cuda` | - | Flag to enable using a GPU to host the model. Flag is true by default. |
For example, to run `lmarena-ai/p2l-7b-grk-02222025`, a Qwen2-based "grk" model with head type `rk`, we would run:
```bash
python -m p2l.endpoint --model-path lmarena-ai/p2l-7b-grk-02222025 --model-type qwen2 --head-type rk --api-key <your-desired-api-key>
```
This will host the model with 1 worker on host 0.0.0.0 and port 10250 by default. Reload is enabled, meaning code changes will reload the endpoint. Note that by default the endpoint expects to load the model onto a GPU; however, by specifying `--no-cuda` you can run on CPU only, which may work for smaller P2L models.
Each P2L model has an associated model list, which specifies which model each index of the output coefficients corresponds to. Below is an example function to get this model list from the hosted endpoint:
```python
import requests
from typing import List


def get_p2l_endpoint_models(base_url: str, api_key: str) -> List[str]:
headers = {
"Content-Type": "application/json",
"api-key": api_key,
}
try:
response = requests.get(f"{base_url}/models", headers=headers)
response.raise_for_status()
result = response.json()
return result["models"]
except Exception as err:
print(f"An error occurred: {err}")
```
Below is an example python function to query the P2L endpoint:
```python
import json

import requests
from typing import Dict, List


def query_p2l_endpoint(
prompt: list[str], base_url: str, api_key: str
) -> Dict[str, List]:
headers = {
"Content-Type": "application/json",
"api-key": api_key,
}
payload = {"prompt": prompt}
try:
response = requests.post(
f"{base_url}/predict", headers=headers, data=json.dumps(payload)
)
response.raise_for_status()
result = response.json()
return result
except Exception as err:
raise err
```
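Putting the two helpers together, you can turn a single prediction into a prompt-specific ranking by pairing each coefficient with its model name. Note that the `"coefs"` key below is an assumption about the `/predict` response schema; double-check the actual field name in [`p2l/endpoint.py`](./p2l/endpoint.py) before relying on it:
```python
base_url = "http://0.0.0.0:10250"  # default endpoint address
api_key = "<your-desired-api-key>"

models = get_p2l_endpoint_models(base_url, api_key)
result = query_p2l_endpoint(["what's 1+1?"], base_url, api_key)

# NOTE: "coefs" is an assumed response key for the coefficient vector;
# inspect the /predict response to confirm the real field name.
coefs = result["coefs"]
ranking = sorted(zip(models, coefs), key=lambda pair: pair[1], reverse=True)
for name, beta in ranking[:10]:
    print(f"{name}: {beta:.4f}")
```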
Note that the input is a list of strings. This is NOT for a batch of prompts, but rather one string per user turn in a conversation. For example, given a 2 turn conversation:
```
User: "hi!"
Assistant: "Hello"
User: "what's 1+1?"
```
The correct P2L input would be:
```python
["hi!", "what's 1+1?"]
```
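For conversations stored as OpenAI-style message lists, the P2L input is simply the user turns in order. A minimal sketch of the conversion (the message history below is the example conversation above):
```python
messages = [
    {"role": "user", "content": "hi!"},
    {"role": "assistant", "content": "Hello"},
    {"role": "user", "content": "what's 1+1?"},
]

# P2L takes only the user turns, one string per turn, in order.
prompt = [m["content"] for m in messages if m["role"] == "user"]
assert prompt == ["hi!", "what's 1+1?"]
```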
## Serving an OpenAI Compatible Router
Serve an OpenAI compatible router with `python -m route.openai_server`. The available arguments are shown below.
| Option | Short Flag | Description |
|--------|-----------|-------------|
| `--help` | `-h` | Show this help message and exit. |
| `--config CONFIG` | `-c CONFIG` | Path to the configuration file. |
| `--router-type ROUTER_TYPE` | - | Type of the router to use. Available types are `bt-endpoint` and `grk-endpoint`.|
| `--router-model-name ROUTER_MODEL_NAME` | - | Name of the router model. |
| `--router-model-endpoint ROUTER_MODEL_ENDPOINT` | - | Endpoint URL for the router model. |
| `--router-api-key ROUTER_API_KEY` | - | API key for the router authentication. |
| `--cost-optimizer COST_OPTIMIZER` | - | Enable or configure cost optimization settings. Available types are `optimal-lp`, `simple-lp`, `strict`.|
| `--port PORT` | `-p PORT` | Port to run the server on. |
| `--host HOST` | - | Host to run the server on. |
| `--api-key API_KEY` | - | API key for authorization. |
| `--reload, --no-reload` | - | Whether to reload the endpoint on detected code changes (requires workers to be set to 1). |
| `--workers WORKERS` | - | Number of endpoint workers (each will hold a model instance). |
### Example: serving a Bradley-Terry based cost-optimal router
First, similar to [above](#serving-p2l), we need to start serving a P2L model, this time Bradley-Terry based. To do this, run:
```bash
python -m p2l.endpoint --model-path lmarena-ai/p2l-7b-bt-01132025 --model-type qwen2 --head-type bt --api-key <your-desired-api-key>
```
Now, we need to configure a routing config file. This will specify the available models and inference details for the router.
For example, here is a configuration that specifies Claude-3.5-Sonnet and GPT-4o:
```yaml
model_configs:
claude-3-5-sonnet-20241022:
api_key: <your-api-key>
base_url: null
cost: 9.3110239362
max_tokens: 8192
name: claude-3-5-sonnet-20241022
system_prompt: null
temp: 0.7
top_p: 0.7
type: anthropic
gpt-4o-2024-05-13:
api_key: <your-api-key>
base_url: null
cost: 12.3166873868
name: gpt-4o-2024-05-13
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
```
Notice how the system prompt, temperature, and top_p are defined. These replicate how the models are served on Chatbot Arena. P2L is trained with the expectation that the models are running on this configuration. Therefore, for the most reliable results, we recommend sticking to the configs shown in [`example_config.yaml`](./route/example_config.yaml), though alternatives should still function well.
Additionally, we allow for adjustment of the `cost` parameter. One natural choice is simply cost per output token; however, more accurate cost estimates are better. For example, the costs in [`example_config.yaml`](./route/example_config.yaml) are calculated to be proportional to the formula `cost_per_output_token * average_output_tokens_per_response`.
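For illustration, a `cost` value following that formula could be computed as in the sketch below. The per-token price, average output length, and scaling factor are placeholder numbers, not values used in this repository:
```python
# Placeholder numbers for illustration only.
cost_per_output_token = 15.0 / 1_000_000   # e.g. $15 per 1M output tokens
average_output_tokens_per_response = 620   # measured from your own traffic

# Costs only need to be comparable across models, so any common
# proportional scaling factor works.
cost = cost_per_output_token * average_output_tokens_per_response * 1_000
print(round(cost, 4))  # -> 9.3
```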
Now, let's assume we put the above config content into `config.yaml`. To start the OpenAI compatible router, we would run:
```bash
python -m route.openai_server --config config.yaml --router-type bt-endpoint --router-model-endpoint http://0.0.0.0:10250 --router-api-key <your-api-key> --cost-optimizer optimal-lp --api-key <your-endpoint-api-key>
```
Let's break down what this command means:
- `--router-type bt-endpoint`: we are using a Bradley-Terry based P2L model hosted on an endpoint.
- `--router-model-endpoint http://0.0.0.0:10250`: this is where the P2L model endpoint is served; this default address will generally be correct if you are running the routing server on the same machine as the P2L endpoint.
- `--cost-optimizer optimal-lp`: we are using cost routing using the optimal linear program detailed in Theorem 1 of the paper.
>**Note**: `optimal-lp` is only compatible with BT models, and `simple-lp` is only compatible with grounded RK (sometimes specified as bag) models.
### Example: serving a Grounded RK based simple cost router
P2L has a class of "Grounded RK" models. These models produce coefficients such that `0.0` represents the threshold for a "usable" answer. We can leverage this to cost route to maximize $P(\text{Not Bad})$... whatever that means exactly. Below we detail the steps to run this routing setup.
First, start up the P2L endpoint:
```bash
python -m p2l.endpoint --model-path lmarena-ai/p2l-7b-grk-02222025 --model-type qwen2 --head-type rk --api-key <your-desired-api-key>
```
Then start up the router server:
```bash
python -m route.openai_server --config config.yaml --router-type grk-endpoint --router-model-endpoint http://0.0.0.0:10250 --router-api-key <your-api-key> --cost-optimizer simple-lp --api-key <your-endpoint-api-key>
```
## Calling the OpenAI Compatible Router
As the name suggests, the router server is OpenAI compatible. We can call it like any other OpenAI compatible model:
```python
from openai import OpenAI
client = OpenAI(
    base_url="<your_router_endpoint_url>/v1",
    api_key="<your_router_api_key>",
)
prompt = "what's 828913*1234?"
response = client.chat.completions.create(
model="-", # This field is actually not used
    messages=[{"role": "user", "content": prompt}],
stream=True, # Router is compatible with and without streaming.
)
# Notice no temperature, top_p, or system prompt is set.
# This allows the router to use the default provided by the config file.
# If you do pass in these fields, they will override the config.
```
If we want to specify a cost budget, we need to do the following:
```python
response = client.chat.completions.create(
model="-", # This field is actually not used
    messages=[{"role": "user", "content": prompt}],
stream=True, # Router is compatible with and without streaming.
extra_body={"cost": <desired_cost>}
)
```
## Training a P2L Model
This codebase also contains the training code for P2L models. To train a P2L model, first set up a training config. The [`training_configs`](./training_configs/) directory has many examples.
To train, run, for example:
```bash
deepspeed --num_gpus=8 --module p2l.train --config training_configs/<your_config>.yaml --no-eval --save-steps 512
```
## Inferencing a P2L Model
To quickly inference on a dataset using P2L, run:
```bash
python -m p2l.eval --model <p2l_model_name> --dataset <hf_dataset_path> --head-type <head_type> --model-type <qwen2_or_llama> --batch-size 2
```
This works on any dataset of single-turn prompts stored under the column name `prompt`.
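For example, a small dataset in the expected format can be built with the `datasets` library. The prompts and repository name below are placeholders:
```python
from datasets import Dataset

# Placeholder prompts; the only requirement is a single-turn "prompt" column.
ds = Dataset.from_dict(
    {"prompt": ["what's 1+1?", "write a haiku about routers"]}
)

# Push to the Hugging Face Hub so it can be passed as <hf_dataset_path> above
# (requires `huggingface-cli login`; the repo name is a placeholder).
ds.push_to_hub("<your-username>/p2l-eval-prompts")
```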
## AutoEval Suite
Our in-depth evaluation code can be run using `p2l.auto_evals`.
### Params
- **a. Model List Params**
1. Either provide `--model_repo`, which has a `model_list.json` file.
2. Or provide a local `--model_list_path` file.
- **b. Val Data**
1. **Data is in JSONL format**:
- Provide a local `--eval_path`.
- If no path is provided, the program will look for an `eval_output.jsonl` file in the `--model_repo` on HF.
2. **Data is in JSON format (checkpoint files)**:
- Provide a local `--checkpoint_path`.
- Or provide remote `--hf_checkpoint_repo` and `--hf_checkpoint_file`.
- **c. Output Directory**
1. Provide a local `--output_dir` or a remote `--hf_output_dir`.
2. Provide `--output_file_name`.
- **d. Train Data (Optional)**
- Provide `--hf_train_dataset` or a local `--train_path`.
- **e. Arena Data (Optional)**
- Provide a local `--arena_path` (CSV with model rankings).
- **f. Provide Model Info**
1. `--loss_type` (e.g., `bt`, `bt_tie`, `rk`).
2. `--model_type` (e.g., `p2l`, `marginal`, `arena`, `marginal-gt`).
3. `--categories`.
- **g. Provide Types of Metrics**
1. `--simple_metrics`, `--category_metrics`, `--rand_subset_metrics`, `--aggr_scale_subset_metrics`.
2. Use `--metrics_to_inc` to select which of the above metrics to include.
- **h. Random Subset Params**
1. `--rand_subset_sizes`: Specify subset sizes.
2. `--rand_num_samples`: Specify the number of samples per random subset size.
- **i. Aggregation Subset Params**
1. `--aggr_scale_subset_sizes`: Specify subset sizes.
2. `--aggr_scale_num_samples`: Specify the number of samples per aggregation subset size.
3. `--aggr_scale_gt`: Specify whether to use `marginal-gt` or `arena` as ground truth for categories.
---
## Citation
```
@misc{frick2025prompttoleaderboard,
title={Prompt-to-Leaderboard},
author={Evan Frick and Connor Chen and Joseph Tennyson and Tianle Li and Wei-Lin Chiang and Anastasios N. Angelopoulos and Ion Stoica},
year={2025},
eprint={2502.14855},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2502.14855},
}
```
================================================
FILE: deepspeed/zero1.json
================================================
{
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": true,
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": "auto"
}
}
}
================================================
FILE: fast_lambda_setup.sh
================================================
sudo apt-get update -y
sudo apt-get install tmux -y
sudo apt-get install python3-dev -y
sudo apt-get install tmux libaio-dev libopenmpi-dev python3-mpi4py -y
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
uv venv .env --python 3.10
source .env/bin/activate
uv pip install wheel packaging
uv pip install -r train_requirements.txt
uv pip install flash-attn==2.5.9.post1 --no-build-isolation
================================================
FILE: fast_runpod_setup.sh
================================================
apt-get update -y
apt-get install tmux -y
apt-get install python3-dev -y
apt-get install tmux libaio-dev libopenmpi-dev python3-mpi4py -y
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
uv venv .env --python 3.10
source .env/bin/activate
uv pip install wheel packaging
uv pip install -r train_requirements.txt
uv pip install flash-attn==2.5.9.post1 --no-build-isolation
================================================
FILE: p2l/auto_eval_utils.py
================================================
from typing import Callable, Dict
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
from scipy.optimize import minimize
from scipy.stats import kendalltau, spearmanr
from model import (
registered_losses,
HeadOutputs,
registered_aggr_models,
registered_pairwise_losses,
)
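# Registries mapping loss/model type -> metric or helper name -> callable.
# The decorator factories below let each metric and helper register itself.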
registered_simple_metrics: Dict[str, Dict[str, Callable]] = {}
registered_aggr_metrics: Dict[str, Dict[str, Callable]] = {}
registered_helpers: Dict[str, Dict[str, Callable]] = {}
def register_simple_metric(loss_type: str, metric: str):
def decorator(func: Callable):
if loss_type not in registered_simple_metrics:
registered_simple_metrics[loss_type] = {}
registered_simple_metrics[loss_type][metric] = func
return func
return decorator
def register_aggr_metric(loss_type: str, metric: str):
def decorator(func: Callable):
if loss_type not in registered_aggr_metrics:
registered_aggr_metrics[loss_type] = {}
registered_aggr_metrics[loss_type][metric] = func
return func
return decorator
def register_helper(loss_or_model_type: str, helper_name: str):
    def decorator(func: Callable):
        if loss_or_model_type not in registered_helpers:
            registered_helpers[loss_or_model_type] = {}
        registered_helpers[loss_or_model_type][helper_name] = func
        return func
    return decorator
@register_helper("p2l", "output_labels")
def output_labels_p2l(val_data: pd.DataFrame, **kwargs):
betas = torch.tensor(np.stack(val_data["betas"]), dtype=torch.float)
labels = torch.tensor(np.stack(val_data["labels"]))
etas = None
if "eta" in val_data.columns:
etas = torch.tensor(np.stack(val_data["eta"]), dtype=torch.float)
return HeadOutputs(coefs=betas, eta=etas), labels
def translate_coefs(coef, old_list, new_list):
old_list = old_list.tolist()
old_to_new = [old_list.index(model) for model in new_list]
betas_array = np.array(coef)
betas_array = betas_array[old_to_new]
return torch.tensor(betas_array)
@register_helper("marginal", "output_labels")
def output_labels_marginal(
val_data: pd.DataFrame,
train_data: pd.DataFrame,
model_list: np.array,
train_model_list: np.array,
loss_type: str,
**kwargs,
):
train_labels = torch.tensor(np.stack(train_data["labels"]))
coefs, eta = train_marginal(train_model_list, train_labels, loss_type)
coefs, eta = coefs[0], eta[0] if eta is not None else None
coefs = translate_coefs(coefs, train_model_list, model_list)
val_labels = torch.tensor(np.stack(val_data["labels"]))
coefs = coefs.expand(len(val_labels), -1)
eta = eta.expand(len(val_labels), -1) if eta is not None else None
return HeadOutputs(coefs=coefs, eta=eta), val_labels
@register_helper("marginal-gt", "output_labels")
def output_labels_marginal_gt(
val_data: pd.DataFrame, model_list: np.array, loss_type: str, **kwargs
):
val_labels = torch.tensor(np.stack(val_data["labels"]))
coefs, eta = train_marginal(model_list, val_labels, loss_type)
coefs = coefs.expand(len(val_labels), -1)
eta = eta.expand(len(val_labels), -1) if eta is not None else None
return HeadOutputs(coefs=coefs, eta=eta), val_labels
@register_helper("arena", "output_labels")
def output_labels_arena(
arena_rankings: torch.tensor, val_data: pd.DataFrame, loss_type: str, **kwargs
):
labels = torch.tensor(np.stack(val_data["labels"]))
# arena rankings is already filtered so it will be 1d tensor
betas = arena_rankings.expand(len(labels), -1)
etas = torch.ones(len(labels))
etas = etas.unsqueeze(-1)
# TODO: Cleanup
if loss_type == "bt" or loss_type == "bt-tie":
etas = None
return HeadOutputs(coefs=betas, eta=etas), labels
@register_helper("bag", "preprocess_data")
def preprocess_data_bag(data: pd.DataFrame, **kwargs):
condition = data["winner"] == "tie (bothbad)"
data.loc[condition, "labels"] = data.loc[condition, "labels"].apply(
lambda arr: arr[:2] + [2]
)
return data
@register_helper("bt", "preprocess_data")
@register_helper("bt-tie", "preprocess_data")
@register_helper("rk", "preprocess_data")
@register_helper("rk-reparam", "preprocess_data")
def preprocess_data(data: pd.DataFrame, **kwargs):
return data
@register_simple_metric("bt", "Loss")
@register_simple_metric("bt", "BCELoss")
@register_simple_metric("bt-tie", "Loss")
@register_simple_metric("rk", "Loss")
@register_simple_metric("rk-reparam", "Loss")
@register_simple_metric("bag", "Loss")
def loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):
loss_func = registered_losses.get(loss_type)
return loss_func(head_output=head_output, labels=labels).item()
@register_simple_metric("rk", "Tie_Loss")
@register_simple_metric("bag", "Tie_Loss")
def tie_loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):
loss_func = registered_losses.get("tie-" + loss_type)
return loss_func(head_output=head_output, labels=labels).item()
@register_simple_metric("bag", "Tie_bb_Loss")
def tie_bb_loss(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
loss_func = registered_losses.get("tie-bb-" + loss_type)
return loss_func(head_output=head_output, labels=labels).item()
@register_aggr_metric("bt", "Aggr_Tie_Loss")
@register_aggr_metric("bt-tie", "Aggr_Tie_Loss")
@register_aggr_metric("rk", "Aggr_Tie_Loss")
@register_aggr_metric("rk-reparam", "Aggr_Tie_Loss")
@register_aggr_metric("bag", "Aggr_Tie_Loss")
def Aggr_Tie_Loss(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Tie_Loss", loss_type, labels, gt_output, model_output)
@register_simple_metric("bt-tie", "BCELoss")
@register_simple_metric("rk", "BCELoss")
@register_simple_metric("rk-reparam", "BCELoss")
@register_simple_metric("bag", "BCELoss")
def BCE_loss(head_output: HeadOutputs, labels: torch.Tensor, **kwargs):
non_tie_index = torch.where(labels[:, -1] == 0)[0]
new_coefs = head_output.coefs[non_tie_index, :]
new_eta = head_output.eta[non_tie_index] if head_output.eta is not None else None
no_tie_output = HeadOutputs(coefs=new_coefs, eta=new_eta)
no_tie_labels = labels[non_tie_index, :]
return loss(no_tie_output, no_tie_labels, loss_type="bt")
def aggr_metric(metric_name, loss_type, labels, gt_output, model_output):
func = registered_simple_metrics[loss_type][metric_name]
gt = func(
labels=labels, head_output=expand_output(gt_output, labels), loss_type=loss_type
)
model = func(
labels=labels,
head_output=expand_output(model_output, labels),
loss_type=loss_type,
)
return {"ground-truth": round(gt, 4), "model-aggr": round(model, 4)}
@register_aggr_metric("bt", "Aggr_Loss")
@register_aggr_metric("bt-tie", "Aggr_Loss")
@register_aggr_metric("rk", "Aggr_Loss")
@register_aggr_metric("rk-reparam", "Aggr_Loss")
@register_aggr_metric("bag", "Aggr_Loss")
def Aggr_Loss(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Loss", loss_type, labels, gt_output, model_output)
@register_aggr_metric("bt", "Aggr_BCELoss")
@register_aggr_metric("bt-tie", "Aggr_BCELoss")
@register_aggr_metric("rk", "Aggr_BCELoss")
@register_aggr_metric("rk-reparam", "Aggr_BCELoss")
@register_aggr_metric("bag", "Aggr_BCELoss")
def Aggr_BCE_Loss(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("BCELoss", loss_type, labels, gt_output, model_output)
def expand_output(output, labels):
coefs, eta = output.coefs, output.eta
new_coefs = coefs.expand(len(labels), -1)
if eta is not None:
eta = eta.expand(len(labels), -1)
return HeadOutputs(coefs=new_coefs, eta=eta)
@register_simple_metric("bt", "MSELoss")
def BT_mse(
head_output: HeadOutputs,
labels: torch.Tensor,
**kwargs,
):
coefs = head_output.coefs
paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
predicted_probs = torch.sigmoid(paired_delta_logit)
true_labels = torch.ones_like(predicted_probs)
mse = F.mse_loss(predicted_probs, true_labels)
return mse.mean().item()
@register_simple_metric("bt-tie", "MSELoss")
def BT_tie_mse(
head_output: HeadOutputs,
labels: torch.Tensor,
**kwargs,
):
coefs = head_output.coefs
model_idx = labels[:, :2]
paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
p_w = torch.sigmoid(paired_delta_logit)
tie_ind = labels[:, -1]
# let label be 0.5 if there is tie
pred_probs = torch.where(tie_ind == 1, 0.5, p_w)
true_labels = torch.ones_like(pred_probs)
mse = F.mse_loss(pred_probs, true_labels)
return mse.mean().item()
@register_simple_metric("rk", "MSELoss")
@register_simple_metric("rk-reparam", "MSELoss")
def RK_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):
probs_func = registered_helpers[loss_type]["probs"]
p_w, _, p_t = probs_func(head_output=head_output, labels=labels)
tie_ind = labels[:, -1]
# True label will always be win (since first index) unless a tie occurs
pred_probs = torch.where(tie_ind == 1, p_t, p_w)
true_labels = torch.ones_like(pred_probs)
mse = F.mse_loss(pred_probs, true_labels)
return mse.mean().item()
@register_simple_metric("bag", "MSELoss")
def bag_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs):
probs_func = registered_helpers[loss_type]["probs"]
p_w, _, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)
tie_ind = labels[:, -1].unsqueeze(-1)
P = torch.stack([p_w, p_t, p_t_bb], dim=-1)
pred_probs = P.gather(dim=-1, index=tie_ind).contiguous().squeeze(-1)
true_labels = torch.ones_like(pred_probs)
mse = F.mse_loss(pred_probs, true_labels)
return mse.mean().item()
@register_helper("rk-reparam", "probs")
def rk_reparam_probs(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = head_output.eta
theta = (torch.exp(eta) + 1.000001).squeeze(-1)
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()[:, 0]
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()[:, 0]
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
p_win = pi_win / (pi_win + theta * pi_lose + 1.0)
p_lose = pi_lose / (pi_lose + theta * pi_win + 1.0)
p_tie = 1.0 - p_win - p_lose
return p_win, p_lose, p_tie
@register_helper("bag", "probs")
def bag_probs(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = head_output.eta
theta = (torch.exp(eta) + 1.000001).squeeze(-1)
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()[:, 0]
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()[:, 0]
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
pi_gamma = 1.0
p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)
p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)
p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)
p_tie = 1.0 - p_win - p_lose - p_tie_bb
return p_win, p_lose, p_tie, p_tie_bb
@register_helper("rk", "probs")
def rk_probs(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = rk_eta(head_output)
model_idx = labels[:, :2]
paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
p_w = torch.sigmoid(paired_delta_logit - eta)
p_l = torch.sigmoid(-1 * paired_delta_logit - eta)
p_t = 1 - p_w - p_l
return p_w, p_l, p_t
@register_simple_metric("bt", "Accuracy")
def BT_accuracy(
head_output: HeadOutputs,
labels: torch.Tensor,
**kwargs,
):
coefs = head_output.coefs
paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
# winner would have positive difference
correct = (paired_delta_logit > 0).float()
return correct.mean().item()
@register_simple_metric("bt-tie", "Accuracy")
def BT_tie_accuracy(
head_output: HeadOutputs,
labels: torch.Tensor,
**kwargs,
):
coefs = head_output.coefs
paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
# winner would have positive difference
correct = (paired_delta_logit > 0).float()
tie_ind = labels[:, -1]
# we give ties half the accuracy
correct[tie_ind == 1] = 0.5
return correct.mean().item()
@register_simple_metric("rk", "Accuracy")
@register_simple_metric("rk-reparam", "Accuracy")
def RK_accuracy(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
probs_func = registered_helpers[loss_type]["probs"]
p_w, p_l, p_t = probs_func(head_output=head_output, labels=labels)
pred_labels = torch.where(
p_w >= p_l, torch.where(p_w >= p_t, 1, 0.5), torch.where(p_l >= p_t, 0, 0.5)
)
tie_ind = labels[:, -1]
# tie if tie index, else winner (first index) predicted to win
true_labels = torch.where(tie_ind == 1, 0.5, 1)
correct = (pred_labels == true_labels).float()
return correct.mean().item()
@register_simple_metric("rk", "Tie_Accuracy")
def RK_tie_accuracy(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
probs_func = registered_helpers[loss_type]["probs"]
p_w, p_l, p_t = probs_func(head_output=head_output, labels=labels)
p_nt = p_w + p_l
pred_tie = torch.where(p_t >= p_nt, 1, 0)
tie_ind = labels[:, -1]
correct = (pred_tie == tie_ind).float()
return correct.mean().item()
@register_simple_metric("bag", "Tie_Accuracy")
def bag_tie_accuracy(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
probs_func = registered_helpers[loss_type]["probs"]
p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)
p_nt = p_w + p_l
p_tie = p_t + p_t_bb
pred_tie = torch.where(p_nt >= p_tie, 0, 1)
tie_ind = torch.where(labels[:, -1] == 0, 0, 1)
correct = (pred_tie == tie_ind).float()
return correct.mean().item()
@register_simple_metric("bag", "Tie_bb_Accuracy")
def bag_tie_bb_accuracy(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
probs_func = registered_helpers[loss_type]["probs"]
p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)
p_nt_bb = p_w + p_l + p_t
pred_tie = torch.where(p_t_bb >= p_nt_bb, 1, 0)
tie_ind = torch.where(labels[:, -1] == 2, 1, 0)
correct = (pred_tie == tie_ind).float()
return correct.mean().item()
@register_aggr_metric("bt", "Aggr_Tie_Accuracy")
@register_aggr_metric("bt-tie", "Aggr_Tie_Accuracy")
@register_aggr_metric("rk", "Aggr_Tie_Accuracy")
@register_aggr_metric("rk-reparam", "Aggr_Tie_Accuracy")
@register_aggr_metric("bag", "Aggr_Tie_Accuracy")
def Aggr_Tie_accuracy(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Tie_Accuracy", loss_type, labels, gt_output, model_output)
@register_aggr_metric("bt", "Aggr_Tie_Accuracy")
@register_aggr_metric("bt-tie", "Aggr_Tie_Accuracy")
@register_aggr_metric("rk", "Aggr_Tie_Accuracy")
@register_aggr_metric("rk-reparam", "Aggr_Tie_Accuracy")
@register_aggr_metric("bag", "Aggr_Tie_Accuracy")
def Aggr_Tie_accuracy(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Tie_Accuracy", loss_type, labels, gt_output, model_output)
@register_aggr_metric("bt", "Aggr_Tie_bb_Accuracy")
@register_aggr_metric("bt-tie", "Aggr_Tie_bb_Accuracy")
@register_aggr_metric("rk", "Aggr_Tie_bb_Accuracy")
@register_aggr_metric("rk-reparam", "Aggr_Tie_bb_Accuracy")
@register_aggr_metric("bag", "Aggr_Tie_bb_Accuracy")
def Aggr_Tie_bb_accuracy(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Tie_bb_Accuracy", loss_type, labels, gt_output, model_output)
@register_aggr_metric("bt", "Aggr_Tie_bb_Loss")
@register_aggr_metric("bt-tie", "Aggr_Tie_bb_Loss")
@register_aggr_metric("rk", "Aggr_Tie_bb_Loss")
@register_aggr_metric("rk-reparam", "Aggr_Tie_bb_Loss")
@register_aggr_metric("bag", "Aggr_Tie_bb_Loss")
def Aggr_Tie_bb_loss(
gt_output: HeadOutputs,
model_output: HeadOutputs,
loss_type: str,
labels: torch.tensor,
**kwargs,
):
return aggr_metric("Tie_bb_Loss", loss_type, labels, gt_output, model_output)
@register_simple_metric("rk-reparam", "Tie_Accuracy")
@register_simple_metric("bt", "Tie_Accuracy")
@register_simple_metric("bt-tie", "Tie_Accuracy")
@register_simple_metric("bt", "Tie_bb_Loss")
@register_simple_metric("rk-reparam", "Tie_bb_Loss")
@register_simple_metric("bt-tie", "Tie_bb_Loss")
@register_simple_metric("rk", "Tie_bb_Loss")
@register_simple_metric("bt", "Tie_Loss")
@register_simple_metric("bt-tie", "Tie_Loss")
@register_simple_metric("rk-reparam", "Tie_Loss")
@register_simple_metric("rk", "Tie_bb_Accuracy")
@register_simple_metric("rk-reparam", "Tie_bb_Accuracy")
@register_simple_metric("bt", "Tie_bb_Accuracy")
@register_simple_metric("bt-tie", "Tie_bb_Accuracy")
def not_implemented(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
return -1 # not implemented
@register_simple_metric("bag", "Accuracy")
def bag_accuracy(
head_output: HeadOutputs, labels: torch.Tensor, loss_type: str, **kwargs
):
probs_func = registered_helpers[loss_type]["probs"]
p_w, p_l, p_t, p_t_bb = probs_func(head_output=head_output, labels=labels)
P = torch.stack([p_w, p_t, p_t_bb, p_l], dim=-1)
pred_labels = P.argmax(dim=-1)
tie_ind = labels[:, -1]
# let win be 0, tie be 1, tie_bb be 2. loss never predicted since winner_idx first
true_labels = tie_ind
correct = (pred_labels == true_labels).float()
return correct.mean().item()
@register_simple_metric("bt", "Mean-BT")
@register_simple_metric("bt-tie", "Mean-BT")
@register_simple_metric("rk", "Mean-BT")
@register_simple_metric("rk-reparam", "Mean-BT")
@register_simple_metric("bag", "Mean-BT")
def beta_mean(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
flat_betas = betas.flatten()
return torch.mean(flat_betas).item()
@register_simple_metric("bt", "Std-BT")
@register_simple_metric("bt-tie", "Std-BT")
@register_simple_metric("rk", "Std-BT")
@register_simple_metric("rk-reparam", "Std-BT")
@register_simple_metric("bag", "Std-BT")
def beta_std(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
flat_betas = betas.flatten()
return torch.std(flat_betas).item()
@register_simple_metric("bt", "Spread-BT")
@register_simple_metric("bt-tie", "Spread-BT")
@register_simple_metric("rk", "Spread-BT")
@register_simple_metric("rk-reparam", "Spread-BT")
@register_simple_metric("bag", "Spread-BT")
def beta_spread(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
flat_betas = betas.flatten()
return (torch.max(flat_betas) - torch.min(flat_betas)).item()
@register_simple_metric("bt", "Mean-Spread-BT")
@register_simple_metric("bt-tie", "Mean-Spread-BT")
@register_simple_metric("rk", "Mean-Spread-BT")
@register_simple_metric("rk-reparam", "Mean-Spread-BT")
@register_simple_metric("bag", "Mean-Spread-BT")
def beta_mean_spread(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
max_min_per_prompt = (
torch.max(betas, dim=-1).values - torch.min(betas, dim=-1).values
)
return torch.mean(max_min_per_prompt).item()
@register_simple_metric("bt", "Mean-IQR-BT")
@register_simple_metric("bt-tie", "Mean-IQR-BT")
@register_simple_metric("rk", "Mean-IQR-BT")
@register_simple_metric("rk-reparam", "Mean-IQR-BT")
@register_simple_metric("bag", "Mean-IQR-BT")
def beta_mean_iqr(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
iqr_per_prompt = torch.quantile(betas, 0.75, dim=-1) - torch.quantile(
betas, 0.25, dim=-1
)
return torch.mean(iqr_per_prompt).item()
@register_simple_metric("bt", "Mean-Std-BT")
@register_simple_metric("bt-tie", "Mean-Std-BT")
@register_simple_metric("rk", "Mean-Std-BT")
@register_simple_metric("rk-reparam", "Mean-Std-BT")
@register_simple_metric("bag", "Mean-Std-BT")
def beta_mean_std(
head_output: HeadOutputs,
**kwargs,
):
betas = head_output.coefs
std_per_prompt = torch.std(betas, dim=-1)
return torch.mean(std_per_prompt).item()
@register_helper("marginal-gt", "aggregrate")
def aggr_marginal_gt(
labels: torch.Tensor, model_list: torch.Tensor, loss_type: str, **kwargs
):
coefs, eta = train_marginal(model_list, labels, loss_type)
return HeadOutputs(coefs=coefs[0], eta=eta[0] if eta is not None else None)
@register_helper("p2l", "aggregrate")
def aggr_p2l(
head_output: HeadOutputs,
labels: torch.Tensor,
model_list: torch.Tensor,
loss_type: str,
**kwargs,
):
coefs, eta = train_aggr_prob(
model_list, head_output, labels, loss_type, is_batch=False
)
return HeadOutputs(coefs=coefs[0], eta=eta[0] if eta is not None else None)
@register_helper("p2l", "aggregrate-batch")
def aggr_p2l_batch(
head_output: HeadOutputs,
labels: torch.Tensor,
model_list: torch.Tensor,
loss_type: str,
**kwargs,
):
coefs_batch, eta_batch = train_aggr_prob(
model_list, head_output, labels, loss_type, is_batch=True
)
return [
HeadOutputs(
coefs=coefs_batch[i], eta=eta_batch[i] if eta_batch is not None else None
)
for i in range(coefs_batch.shape[0])
]
@register_helper("marginal-gt", "aggregrate-batch")
def aggr_p2l_batch(
head_output: HeadOutputs,
labels: torch.Tensor,
model_list: torch.Tensor,
loss_type: str,
**kwargs,
):
# TODO: Make faster if necessary
return [
aggr_marginal_gt(labels[i], model_list, loss_type) for i in range(len(labels))
]
@register_helper("marginal", "aggregrate")
def aggr_non_p2l(head_output: HeadOutputs, loss_type: str, **kwargs):
etas = head_output.eta
etas = etas[0, :] if etas is not None else None
return HeadOutputs(coefs=head_output.coefs[0, :], eta=etas)
@register_helper("arena", "aggregrate")
def aggr_non_p2l(
head_output: HeadOutputs = None, arena_rankings: torch.tensor = None, **kwargs
):
eta = torch.tensor([0])
if arena_rankings is not None:
return HeadOutputs(coefs=arena_rankings, eta=eta)
# arena just has the same betas repeated if not provided
return HeadOutputs(coefs=head_output.coefs[0, :], eta=eta)
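# Fit a single set of aggregate coefficients (a marginal leaderboard) to the
# given labels with LBFGS, using the loss registered for this loss_type.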
def train_marginal(model_list, labels, loss_type, lr=1.0, tol=1e-9, max_epochs=50):
model_cls = registered_aggr_models[loss_type]
model = model_cls(len(model_list))
optimizer = optim.LBFGS(
model.parameters(),
lr=lr,
max_iter=max_epochs,
tolerance_grad=tol,
tolerance_change=tol,
)
loss_func = registered_losses[loss_type]
labels = (
labels.squeeze() if labels.dim() > 2 else labels
) # marginal doesn't use batching since one at a time
def closure():
optimizer.zero_grad()
coefs, eta = model()
coefs_expanded = coefs[0].expand(len(labels), -1)
eta_expanded = eta[0].expand(len(labels), -1) if eta is not None else None
head_output = HeadOutputs(coefs=coefs_expanded, eta=eta_expanded)
loss = loss_func(head_output=head_output, labels=labels)
loss.backward()
return loss
optimizer.step(closure)
true_coefs, true_eta = model()
return true_coefs.detach(), true_eta.detach() if true_eta is not None else None
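# Fit aggregated coefficients whose pairwise win/tie probabilities match the
# probabilities implied by the per-prompt head outputs (optionally batched).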
def train_aggr_prob(
model_list,
head_outputs,
labels,
loss_type,
is_batch,
lr=1.0,
tol=1e-9,
max_epochs=50,
):
true_probs_func = registered_helpers[loss_type]["pairwise_probs"]
true_probs = true_probs_func(real_output=head_outputs)
    # add a batch size of 1 since aggregation is done in batches (only necessary if data isn't in batch format)
if not is_batch:
true_probs = true_probs.unsqueeze(0)
batch_size = true_probs.shape[0]
model_cls = registered_aggr_models[loss_type]
model = model_cls(len(model_list), batch_size)
optimizer = optim.LBFGS(
model.parameters(),
lr=lr,
max_iter=max_epochs,
tolerance_grad=tol,
tolerance_change=tol,
)
loss_func = registered_pairwise_losses[loss_type]
count = 0
prev_loss = 0
def closure():
optimizer.zero_grad()
coefs, eta = model()
aggr_output = HeadOutputs(coefs=coefs, eta=eta)
loss = loss_func(
real_output=head_outputs,
aggregated_output=aggr_output,
true_probs=true_probs,
)
loss.backward()
nonlocal count
count += 1
if count == 49:
raise Warning("Batch training did not converge")
return loss
optimizer.step(closure)
true_coefs, true_eta = model()
return true_coefs.detach(), true_eta.detach() if true_eta is not None else None
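# Map the raw eta output to a non-negative tie margin via a shifted softplus,
# clamped away from zero.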
def rk_eta(output):
if output.eta is None:
return None
BETA = 0.1
return torch.clamp(
torch.nn.functional.softplus(output.eta - 22.5, BETA).squeeze(-1), min=0.02
)
@register_helper("rk", "pairwise_probs")
def pairwise_RK_probs(real_output: HeadOutputs):
real_betas = real_output.coefs
real_eta = rk_eta(real_output)
real_eta = real_eta.unsqueeze(-1)
num_models = real_betas.shape[-1]
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
    # ellipsis indexing allows for both batched/unbatched inputs
beta_i_real = real_betas[..., pair_indices[:, 0]]
beta_j_real = real_betas[..., pair_indices[:, 1]]
true_probs_win = torch.sigmoid(beta_i_real - beta_j_real - real_eta)
true_probs_loss = torch.sigmoid(beta_j_real - beta_i_real - real_eta)
true_probs_tie = 1.0 - true_probs_win - true_probs_loss
true_probs = torch.stack((true_probs_win, true_probs_loss, true_probs_tie), dim=-1)
return true_probs
@register_helper("rk-reparam", "pairwise_probs")
def pairwise_RK_reparam_probs(real_output: HeadOutputs, **kwargs):
real_betas = real_output.coefs
real_theta = torch.exp(real_output.eta) + 1.000001
num_models = real_betas.shape[-1]
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_real = real_betas[..., pair_indices[:, 0]]
beta_j_real = real_betas[..., pair_indices[:, 1]]
pi_win = torch.exp(beta_i_real)
pi_lose = torch.exp(beta_j_real)
p_win = pi_win / (pi_win + real_theta * pi_lose + 1.0)
p_lose = pi_lose / (pi_lose + real_theta * pi_win + 1.0)
p_tie = 1.0 - p_win - p_lose
true_probs = torch.stack((p_win, p_lose, p_tie), dim=-1)
return true_probs
@register_helper("bag", "pairwise_probs")
def pairwise_bag_probs(real_output: HeadOutputs, **kwargs):
real_betas = real_output.coefs
real_theta = torch.exp(real_output.eta) + 1.000001
num_models = real_betas.shape[-1]
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_real = real_betas[..., pair_indices[:, 0]]
beta_j_real = real_betas[..., pair_indices[:, 1]]
pi_win = torch.exp(beta_i_real)
pi_lose = torch.exp(beta_j_real)
pi_gamma = 1.0
p_win = pi_win / (pi_win + real_theta * pi_lose + pi_gamma)
p_lose = pi_lose / (pi_lose + real_theta * pi_win + pi_gamma)
p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)
p_tie = 1.0 - p_win - p_lose - p_tie_bb
true_probs = torch.stack((p_win, p_lose, p_tie, p_tie_bb), dim=-1)
return true_probs
@register_helper("bt", "pairwise_probs")
@register_helper("bt-tie", "pairwise_probs")
def pairwise_BT_probs(real_output: HeadOutputs):
real_betas = real_output.coefs
num_models = real_betas.shape[-1]
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_real = real_betas[..., pair_indices[:, 0]]
beta_j_real = real_betas[..., pair_indices[:, 1]]
true_probs = torch.sigmoid(beta_i_real - beta_j_real)
return true_probs
# removes nan from tensor, indices will be shifted
def remove_beta_nan(beta1, beta2):
beta_mask = ~torch.isnan(beta1) & ~torch.isnan(beta2)
return beta1[beta_mask], beta2[beta_mask]
@register_aggr_metric("bt", "Leaderboard")
@register_aggr_metric("bt-tie", "Leaderboard")
@register_aggr_metric("rk", "Leaderboard")
@register_aggr_metric("rk-reparam", "Leaderboard")
@register_aggr_metric("bag", "Leaderboard")
def leaderboard(
gt_output: HeadOutputs, model_output: HeadOutputs, model_list: np.array, **kwargs
):
gt_lb = get_leaderboard(gt_output, model_list)
model_lb = get_leaderboard(model_output, model_list)
return {"ground-truth": list(gt_lb), "model-aggr": list(model_lb)}
def get_leaderboard(output, model_list):
coefs = output.coefs
sorted_indices = torch.argsort(coefs, descending=True)
sorted_model_names = [model_list[i] for i in sorted_indices]
sorted_betas = coefs[sorted_indices]
leaderboard = []
for i in range(len(sorted_model_names)):
beta = (
round(sorted_betas[i].item(), 4)
if not torch.isnan(sorted_betas[i])
else "nan"
)
cur_model = str(sorted_model_names[i]) + ": " + str(beta)
leaderboard.append(cur_model)
return np.array(leaderboard)
@register_aggr_metric("bt", "L1-Dist-Prob")
@register_aggr_metric("bt-tie", "L1-Dist-Prob")
def l1_dist_prob_bt(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
beta1 = gt_output.coefs
beta2 = model_output.coefs
    # if one of these is arena, there may be nan if a model is not present in that file
beta1, beta2 = remove_beta_nan(beta1, beta2)
diff_matrix1 = beta1.unsqueeze(1) - beta1.unsqueeze(0)
diff_matrix2 = beta2.unsqueeze(1) - beta2.unsqueeze(0)
prob_vec1 = torch.sigmoid(diff_matrix1).flatten()
prob_vec2 = torch.sigmoid(diff_matrix2).flatten()
return torch.abs(prob_vec2 - prob_vec1).mean().item()
@register_aggr_metric("rk-reparam", "L1-Dist-Prob")
@register_aggr_metric("rk", "L1-Dist-Prob")
def l1_dist_prob_rk(
gt_output: HeadOutputs, model_output: HeadOutputs, loss_type: str, **kwargs
):
eta1 = gt_output.eta
eta2 = model_output.eta
# need to both have eta
if eta1 is None or eta2 is None:
return l1_dist_prob_bt(gt_output, model_output)
pair_probs_func = registered_helpers[loss_type]["pairwise_probs"]
p_win1, p_lose1, p_tie1 = torch.unbind(pair_probs_func(gt_output), -1)
p_win2, p_lose2, p_tie2 = torch.unbind(pair_probs_func(model_output), -1)
win_diff = torch.abs(p_win1 - p_win2).mean().item()
lose_diff = torch.abs(p_lose1 - p_lose2).mean().item()
tie_diff = torch.abs(p_tie1 - p_tie2).mean().item()
return (win_diff + lose_diff + tie_diff) / 3
@register_aggr_metric("bag", "L1-Dist-Prob")
def l1_dist_prob_bag(
gt_output: HeadOutputs, model_output: HeadOutputs, loss_type: str, **kwargs
):
eta1 = gt_output.eta
eta2 = model_output.eta
# need to both have eta
if eta1 is None or eta2 is None:
return l1_dist_prob_bt(gt_output, model_output)
pair_probs_func = registered_helpers[loss_type]["pairwise_probs"]
p_win1, p_lose1, p_tie1, p_tie_bb1 = torch.unbind(pair_probs_func(gt_output), -1)
p_win2, p_lose2, p_tie2, p_tie_bb2 = torch.unbind(pair_probs_func(model_output), -1)
win_diff = torch.abs(p_win1 - p_win2).mean().item()
lose_diff = torch.abs(p_lose1 - p_lose2).mean().item()
tie_diff = torch.abs(p_tie1 - p_tie2).mean().item()
tie_bb_diff = torch.abs(p_tie_bb2 - p_tie_bb1).mean().item()
return (win_diff + lose_diff + tie_diff + tie_bb_diff) / 4
@register_aggr_metric("bt", "IQR-BT")
@register_aggr_metric("bt-tie", "IQR-BT")
@register_aggr_metric("rk", "IQR-BT")
@register_aggr_metric("rk-reparam", "IQR-BT")
@register_aggr_metric("bag", "IQR-BT")
def beta_iqr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
(
gt_coefs,
model_coefs,
) = (
gt_output.coefs,
model_output.coefs,
)
gt_iqr = (torch.quantile(gt_coefs, 0.75) - torch.quantile(gt_coefs, 0.25)).item()
model_iqr = (
torch.quantile(model_coefs, 0.75) - torch.quantile(model_coefs, 0.25)
).item()
return {"ground-truth": round(gt_iqr, 4), "model-aggr": round(model_iqr, 4)}
@register_aggr_metric("bt", "Std-BT")
@register_aggr_metric("bt-tie", "Std-BT")
@register_aggr_metric("rk", "Std-BT")
@register_aggr_metric("rk-reparam", "Std-BT")
@register_aggr_metric("bag", "Std-BT")
def beta_std_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = gt_output.coefs, model_output.coefs
gt_std, model_std = (
torch.std(gt_betas.flatten()).item(),
torch.std(model_betas.flatten()).item(),
)
return {"ground-truth": round(gt_std, 4), "model-aggr": round(model_std, 4)}
@register_aggr_metric("bt", "Spread-BT")
@register_aggr_metric("bt-tie", "Spread-BT")
@register_aggr_metric("rk", "Spread-BT")
@register_aggr_metric("rk-reparam", "Spread-BT")
@register_aggr_metric("bag", "Spread-BT")
def beta_spread_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = gt_output.coefs.flatten(), model_output.coefs.flatten()
gt_spread, model_spread = torch.max(gt_betas) - torch.min(gt_betas), torch.max(
model_betas
) - torch.min(model_betas)
return {
"ground-truth": round(gt_spread.item(), 4),
"model-aggr": round(model_spread.item(), 4),
}
@register_aggr_metric("bt", "Kendall-lbs")
@register_aggr_metric("bt-tie", "Kendall-lbs")
@register_aggr_metric("rk", "Kendall-lbs")
@register_aggr_metric("rk-reparam", "Kendall-lbs")
@register_aggr_metric("bag", "Kendall-lbs")
def kendall_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)
gt_lb = gt_betas.numpy()
model_lb = model_betas.numpy()
return kendalltau(gt_lb, model_lb)[0]
@register_aggr_metric("bt", "Spearman-lbs")
@register_aggr_metric("bt-tie", "Spearman-lbs")
@register_aggr_metric("rk", "Spearman-lbs")
@register_aggr_metric("rk-reparam", "Spearman-lbs")
@register_aggr_metric("bag", "Spearman-lbs")
def spearman_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)
gt_lb = gt_betas.numpy()
model_lb = model_betas.numpy()
return spearmanr(gt_lb, model_lb)[0]
def top_k_frac(gt_betas: torch.tensor, model_betas: torch.tensor, k: int):
gt_top_indices = set(torch.topk(gt_betas, k).indices.numpy())
model_top_indices = set(torch.topk(model_betas, k).indices.numpy())
common_indices = gt_top_indices & model_top_indices
return len(common_indices) / k
def top_k_displace(gt_betas: torch.tensor, model_betas: torch.tensor, k: int):
gt_top_indices = torch.topk(gt_betas, k).indices
model_ranks = torch.argsort(torch.argsort(model_betas, descending=True))
displacements = torch.abs(model_ranks[gt_top_indices] - torch.arange(k))
return displacements.float().mean().item()
@register_aggr_metric("bt", "Top-k-fraction")
@register_aggr_metric("bt-tie", "Top-k-fraction")
@register_aggr_metric("rk", "Top-k-fraction")
@register_aggr_metric("rk-reparam", "Top-k-fraction")
@register_aggr_metric("bag", "Top-k-fraction")
def top_k_frac_dict(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)
res = {}
for k in [1, 3, 5, 10]:
res[k] = round(top_k_frac(gt_betas, model_betas, k), 4)
return res
@register_aggr_metric("bt", "Top-k-displace")
@register_aggr_metric("bt-tie", "Top-k-displace")
@register_aggr_metric("rk", "Top-k-displace")
@register_aggr_metric("rk-reparam", "Top-k-displace")
@register_aggr_metric("bag", "Top-k-displace")
def top_k_dist_dict(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
gt_betas, model_betas = remove_beta_nan(gt_output.coefs, model_output.coefs)
res = {}
for k in [1, 3, 5, 10]:
res[k] = round(top_k_displace(gt_betas, model_betas, k), 4)
return res
================================================
FILE: p2l/auto_evals.py
================================================
import argparse
import json
import os
import io
import warnings
import math
from tqdm import tqdm
import time
import copy
import torch
import pandas as pd
import numpy as np
from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, upload_file, list_repo_files
from model import HeadOutputs
from auto_eval_utils import (
registered_simple_metrics,
registered_aggr_metrics,
registered_helpers,
)
def parse_model_list(hf_model, local_path):
if not hf_model and not local_path:
raise ValueError("Either model repo or local model list must be provided.")
model_list_path = local_path
# if no local path, try getting from model_repo
if not model_list_path:
model_list_path = hf_hub_download(
repo_id=hf_model, filename="model_list.json", repo_type="model"
)
model_list = pd.read_json(model_list_path, lines=False).iloc[:, 0].tolist()
return np.array(model_list)
def change_beta_model_list(df, old_list, new_list):
old_list = old_list.tolist()
old_to_new = [old_list.index(model) for model in new_list]
betas_array = np.array(df["betas"].to_list())
betas_array = betas_array[:, old_to_new]
return betas_array.tolist()
def parse_eval_output_data(
model_repo,
local_eval_path,
local_checkpoint_path,
hf_checkpoint_repo,
hf_checkpoint_file,
loss_type,
model_list,
remove_last_hidden_json,
):
ret_df, ret_model_list = None, None
if local_checkpoint_path or hf_checkpoint_repo:
path = local_checkpoint_path
if not path:
if not hf_checkpoint_file:
raise ValueError(
"Must provide checkpoint file along with checkpoint repo"
)
path = hf_hub_download(
repo_id=hf_checkpoint_repo,
filename=hf_checkpoint_file,
repo_type="dataset",
)
df = pd.read_json(path)
# caching json w/o last hidden layer
if remove_last_hidden_json and local_checkpoint_path:
if "last_hidden_state" in df.columns:
df = df.drop(columns=["last_hidden_state"])
df.to_json(local_checkpoint_path)
df = df.rename(columns={"coefs": "betas"})
# data is stored with nested lists for both etas and betas only in checkpoint data
# df['eta'] = np.array(df['eta'].to_list()).flatten()
df["eta"] = df["eta"].apply(lambda x: x[0] if isinstance(x, list) else x)
df["betas"] = df["betas"].apply(lambda x: x[0] if isinstance(x, list) else x)
val_model_list = get_model_list_from_df(df)
# only betas need to be adjusted since labels are correct
df["betas"] = change_beta_model_list(df, model_list, val_model_list)
ret_df, ret_model_list = df, val_model_list
elif local_eval_path:
ret_df, ret_model_list = pd.read_json(local_eval_path, lines=True), model_list
elif model_repo:
files = list_repo_files(repo_id=model_repo, repo_type="model")
if "eval_output.jsonl" not in files:
raise FileNotFoundError(
f"'eval_output.jsonl' not found in the hf repository'{model_repo}'."
)
path = hf_hub_download(
repo_id=model_repo, filename="eval_output.jsonl", repo_type="model"
)
ret_df, ret_model_list = pd.read_json(path, lines=True), model_list
else:
raise ValueError("need to provide path for eval output data")
preprocess_func = registered_helpers[loss_type]["preprocess_data"]
ret_df = preprocess_func(data=ret_df)
return ret_df, ret_model_list
def add_labels_to_data(data, loss_type, model_list):
if loss_type == "bt":
data = data[~data["winner"].isin(["tie", "tie (bothbad)"])]
def create_labels(row):
winner = row["winner"]
model_a = row["model_a"]
model_b = row["model_b"]
model_a_idx = np.where(model_list == model_a)[0][0]
model_b_idx = np.where(model_list == model_b)[0][0]
tie_bb_label = 2 if loss_type == "bag" else 1
if winner == "model_a":
return np.array([model_a_idx, model_b_idx, 0])
elif winner == "model_b":
return np.array([model_b_idx, model_a_idx, 0])
elif winner == "tie":
return np.array([model_a_idx, model_b_idx, 1])
else:
return np.array([model_a_idx, model_b_idx, tie_bb_label])
data["labels"] = data.apply(create_labels, axis=1)
return data
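# Illustrative label encoding (model names hypothetical): with
# model_list = np.array(["model-x", "model-y"]), a battle where model_a = "model-y"
# beats model_b = "model-x" yields labels [1, 0, 0]; a "tie" yields [1, 0, 1]; and a
# "tie (bothbad)" yields [1, 0, 2] under the bag loss (1 otherwise). For the plain
# "bt" loss, tie rows are dropped before labeling.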
# only use if completely necessary
def get_model_list_from_df(df):
return np.array(sorted(pd.concat([df["model_a"], df["model_b"]]).unique()))
def parse_train_data(hf_data, local_path, loss_type, train_model_list):
if not hf_data and not local_path:
warnings.warn(
"No train data provided, marginal model type will not work if specified"
)
return
if local_path:
if local_path.endswith(".jsonl"):
data = pd.read_json(local_path, lines=True)
else:
data = load_from_disk(local_path)["train"].to_pandas()
else:
data = load_dataset(hf_data, split="train").to_pandas()
return add_labels_to_data(data, loss_type, train_model_list)
def parse_arena_data(path, model_list, initial_rating=1000, BASE=10, SCALE=400):
if not path:
warnings.warn("Ground truth arena data not passed in, some metrics not work")
return
df = pd.read_csv(path)
    # drop style-controlled rows to avoid duplicates, since not every model has a style-controlled ranking
df = df[df["style_control"] == False]
    # convert Elo ratings to Bradley-Terry betas, matching the conversion used in eval_p2l.ipynb
df["beta"] = (df["rating"] - initial_rating) / (SCALE * math.log(BASE))
pivot = df.pivot(index="model_name", columns="category", values="beta").reindex(
model_list
)
if pivot.isnull().any().any():
missing_models = pivot[pivot.isnull().any(axis=1)].index.tolist()
warnings.warn("Model not included in arena leaderboard:" + str(missing_models))
category_to_betas = {
category: torch.tensor(pivot[category].values, dtype=torch.float)
for category in pivot.columns
}
return category_to_betas
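# Worked example of the Elo-to-beta conversion above: with the defaults
# (initial_rating=1000, BASE=10, SCALE=400), a rating of 1200 maps to
# beta = (1200 - 1000) / (400 * ln(10)) ≈ 0.217.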
# NOTE: Only accepts certain categories, needs to be manually added
def filter_battle_data(battles, category):
if battles is None:
return None
# expect category key by itself or key=value
key_val_pair = category.split("=")
key = key_val_pair[0]
val = key_val_pair[1] if len(key_val_pair) == 2 else True
    # note: bool("False") is truthy, so map the string to a real boolean explicitly
    val = (val in ["True", "true"]) if val in ["True", "true", "False", "false"] else val
try:
# no filtering
if key == "all":
return battles
# no nesting
if key == "language" or key == "is_code":
return battles[battles[key] == val]
# nested ones need specific cases
if key == "math":
return battles[
battles["category_tag"].apply(lambda x: x["math_v0.1"]["math"])
]
if key == "complexity":
return battles[
battles["category_tag"].apply(
lambda x: x["criteria_v0.1"]["complexity"]
)
]
if key == "creative_writing":
return battles[
battles["category_tag"].apply(
lambda x: x["creative_writing_v0.1"]["creative_writing"]
)
]
if key == "hard":
return battles[
battles["category_tag"].apply(
lambda x: sum(x["criteria_v0.1"].values()) >= 6
)
]
# Category not found
return None
except:
return None
# NOTE: Only accepts certain categories, needs to be manually added
def get_arena_rankings(data, category):
if data is None:
return None
key_val_pair = category.split("=")
key = key_val_pair[0]
val = key_val_pair[1] if len(key_val_pair) == 2 else True
    # note: bool("False") is truthy, so map the string to a real boolean explicitly
    val = (val in ["True", "true"]) if val in ["True", "true", "False", "false"] else val
try:
# no filtering
if key == "all":
return data["full"]
# no nesting
if key == "language":
return data[val.lower()]
if key == "is_code":
return data["coding"]
if key == "math":
return data["math"]
if key == "hard":
return data["hard_6"]
if key == "creative_writing":
return data["creative_writing"]
return None
except:
return None
def get_subset_prompts(output, labels, size):
num_prompts = output.coefs.shape[0]
sampled_indices = torch.randperm(num_prompts)[:size]
sampled_coefs = output.coefs[sampled_indices, :]
sampled_eta = None
if output.eta is not None:
sampled_eta = output.eta[sampled_indices]
sampled_labels = labels[sampled_indices, :]
sampled_output = HeadOutputs(coefs=sampled_coefs, eta=sampled_eta)
return sampled_output, sampled_labels
def get_subset_prompts_batch(output, labels, size, batch_size):
num_prompts, num_models = output.coefs.shape
sampled_indices = torch.randint(low=0, high=num_prompts, size=(batch_size, size))
sampled_coefs = output.coefs[sampled_indices]
sampled_eta = None
if output.eta is not None:
sampled_eta = output.eta[sampled_indices]
sampled_labels = labels[sampled_indices]
sampled_output = HeadOutputs(coefs=sampled_coefs, eta=sampled_eta)
return sampled_output, sampled_labels
def get_ith_output(output, i):
betas = output.coefs[i]
eta = output.eta[i] if output.eta is not None else None
return HeadOutputs(coefs=betas, eta=eta)
def save_output(results, local_dir, hf_dir, file_name):
if not local_dir and not hf_dir:
raise ValueError("Specify a directory for outputs.")
results["params"]["output_file_name"] = file_name
file_name += ".json"
if local_dir:
path = os.path.join(local_dir, file_name)
with open(path, "w") as file:
json.dump(results, file, indent=4, separators=(",", ": "))
if hf_dir:
output = json.dumps(results, indent=4, separators=(",", ": "))
tmp_file = io.BytesIO(output.encode("utf-8"))
upload_file(
path_or_fileobj=tmp_file,
path_in_repo=file_name,
repo_id=hf_dir,
repo_type="model",
)
def simple_metrics(metrics, output, labels, loss_type):
results = {}
for metric in tqdm(metrics, desc="Simple Metrics", unit="metrics"):
metric_dict = registered_simple_metrics[loss_type]
metric_func = metric_dict[metric]
metric_val = metric_func(head_output=output, labels=labels, loss_type=loss_type)
results[metric] = (
round(metric_val, 4) if isinstance(metric_val, float) else metric_val
)
return results
def category_metrics(
metrics,
output,
labels,
loss_type,
model_type,
model_list,
ground_truth,
arena_rankings,
):
results = {}
aggr_func_model = registered_helpers[model_type]["aggregrate"]
# our default ground truth is marginal-gt but we can switch to arena or add configurability if desired
aggr_func_gt = registered_helpers[ground_truth]["aggregrate"]
model_output = aggr_func_model(
head_output=output, labels=labels, model_list=model_list, loss_type=loss_type
)
gt_output = aggr_func_gt(
labels=labels,
model_list=model_list,
loss_type=loss_type,
arena_rankings=arena_rankings,
)
for metric in tqdm(metrics, desc="Category Metrics", unit="metric"):
metric_dict = registered_aggr_metrics[loss_type]
metric_func = metric_dict[metric]
metric_val = metric_func(
gt_output=gt_output,
model_output=model_output,
model_list=model_list,
loss_type=loss_type,
labels=labels,
)
results[metric] = (
round(metric_val, 4) if isinstance(metric_val, float) else metric_val
)
return results
def random_subset_metrics(
metrics,
output,
labels,
subset_sizes,
trials_per_subset,
loss_type,
model_type,
model_list,
):
results = {}
aggr_func_model = registered_helpers[model_type]["aggregrate"]
# our default ground truth is marginal-gt but we can switch to arena or add configurability if desired
aggr_func_gt = registered_helpers["marginal-gt"]["aggregrate"]
for idx, size in enumerate(subset_sizes):
size = int(size)
subset_results = {metric: 0 for metric in metrics}
for _ in tqdm(
range(trials_per_subset[idx]),
desc=f"Random Subset size {size}",
unit="trial",
):
sample_output, sample_labels = get_subset_prompts(output, labels, size)
model_output = aggr_func_model(
head_output=sample_output,
labels=sample_labels,
model_list=model_list,
loss_type=loss_type,
)
gt_output = aggr_func_gt(
labels=sample_labels, model_list=model_list, loss_type=loss_type
)
for metric in metrics:
metric_dict = registered_aggr_metrics[loss_type]
metric_func = metric_dict[metric]
metric_val = metric_func(
gt_output=gt_output,
model_output=model_output,
model_list=model_list,
loss_type=loss_type,
)
subset_results[metric] += metric_val
for metric in metrics:
subset_results[metric] = round(
                subset_results[metric] / int(trials_per_subset[idx]), 4
)
results[size] = subset_results
return results
def aggr_scale_metrics(
metrics,
output,
labels,
subset_sizes,
trials_per_subset,
loss_type,
model_type,
model_list,
arena_rankings,
gt,
):
results = {}
aggr_func_model = registered_helpers[model_type]["aggregrate-batch"]
    # the ground truth here is selected by the gt argument (e.g., arena rankings or marginal-gt)
aggr_func_gt = registered_helpers[gt]["aggregrate"]
gt_output = aggr_func_gt(
labels=labels,
model_list=model_list,
loss_type=loss_type,
arena_rankings=arena_rankings,
)
    # TODO: arbitrary threshold to limit memory consumption for batching
# max_prompts_times_samples_squared = 2e4
for idx, size in enumerate(subset_sizes):
size = int(size)
num_samples = int(trials_per_subset[idx])
subset_results = {metric: 0 for metric in metrics}
# num_full_mini_batches = int(max(
# 1, (size * (num_samples ** 2)) // max_prompts_times_samples_squared
# ))
num_full_mini_batches = int(max(1, num_samples // 100))
mini_batch_size = num_samples // num_full_mini_batches
leftover = num_samples - (num_full_mini_batches * mini_batch_size)
with tqdm(total=num_samples, desc=f"Aggr Subset Size {size}") as pbar:
def run_mini_batch(batch_count):
sample_output, sample_labels = get_subset_prompts_batch(
output, labels, size, batch_count
)
batch_output = aggr_func_model(
head_output=sample_output,
labels=sample_labels,
model_list=model_list,
loss_type=loss_type,
)
for cur_output in batch_output:
for metric in metrics:
metric_dict = registered_aggr_metrics[loss_type]
metric_func = metric_dict[metric]
metric_val = metric_func(
gt_output=gt_output,
model_output=cur_output,
model_list=model_list,
loss_type=loss_type,
)
subset_results[metric] += metric_val
pbar.update(1)
for _ in range(num_full_mini_batches):
run_mini_batch(mini_batch_size)
if leftover > 0:
run_mini_batch(leftover)
for metric in metrics:
subset_results[metric] = round(
subset_results[metric] / float(trials_per_subset[idx]), 4
)
results[size] = subset_results
return results
def get_metrics(
val_data, train_data, arena_rankings, val_model_list, train_model_list, args
):
results = {}
to_inc = set(args.metrics_to_inc)
output_label_func = registered_helpers[args.model_type]["output_labels"]
output, labels = output_label_func(
val_data=val_data,
train_data=train_data,
arena_rankings=arena_rankings,
loss_type=args.loss_type,
model_list=val_model_list,
train_model_list=train_model_list,
)
if "simple" in to_inc:
simple_results = simple_metrics(
metrics=args.simple_metrics,
output=output,
labels=labels,
loss_type=args.loss_type,
)
results["simple_metrics"] = simple_results
if "category" in to_inc:
category_results = category_metrics(
metrics=args.category_metrics,
loss_type=args.loss_type,
model_type=args.model_type,
model_list=val_model_list,
output=output,
labels=labels,
ground_truth=args.ground_truth,
arena_rankings=arena_rankings,
)
results["category_metrics"] = category_results
if "random_subsets" in to_inc:
subset_results = random_subset_metrics(
metrics=args.rand_subset_metrics,
subset_sizes=args.rand_subset_sizes,
trials_per_subset=args.rand_num_samples,
loss_type=args.loss_type,
model_type=args.model_type,
model_list=val_model_list,
output=output,
labels=labels,
)
results["random_subsets"] = subset_results
if "aggr_scale" in to_inc:
scale_results = aggr_scale_metrics(
metrics=args.aggr_scale_metrics,
subset_sizes=args.aggr_scale_subset_sizes,
trials_per_subset=args.aggr_scale_num_samples,
loss_type=args.loss_type,
model_type=args.model_type,
model_list=val_model_list,
output=output,
labels=labels,
arena_rankings=arena_rankings,
gt=args.ground_truth,
)
results["aggr_scale"] = scale_results
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    # the model repo contains the model list and, potentially, eval data (eval_output.jsonl)
parser.add_argument("--model_repo", type=str, default=None)
parser.add_argument("--model_list_path", type=str, default=None)
    # val data comes from the model repo, a local file, or a remote checkpoint file
parser.add_argument("--eval_path", nargs="+", type=str, default=None)
parser.add_argument("--checkpoint_path", nargs="+", type=str, default=None)
parser.add_argument("--hf_checkpoint_repo", type=str, default=None)
parser.add_argument("--hf_checkpoint_file", nargs="+", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--hf_output_dir", type=str, default=None)
parser.add_argument(
"--output_file_name", type=str, nargs="+", default=["eval_metrics"]
)
parser.add_argument("--hf_train_dataset", type=str, default=None)
parser.add_argument("--train_path", type=str, default=None)
parser.add_argument("--arena_path", type=str, default=None)
parser.add_argument("--loss_type", type=str, default="bt", help="bt, bt_tie, rk")
parser.add_argument(
"--model_type", type=str, default="p2l", help="p2l, marginal, arena"
)
parser.add_argument(
"--categories",
nargs="*",
default=[
"all",
"creative_writing",
"math",
"language=Chinese",
"is_code",
"hard",
],
)
parser.add_argument(
"--simple_metrics",
nargs="*",
default=[
"Loss",
"BCELoss",
"MSELoss",
"Accuracy",
"Tie_Loss",
"Tie_Accuracy",
"Tie_bb_Accuracy",
"Tie_bb_Loss",
"Mean-BT",
"Std-BT",
"Spread-BT",
"Mean-Spread-BT",
"Mean-IQR-BT",
"Mean-Std-BT",
],
)
parser.add_argument("--train_checkpoints", nargs="+", type=int, default=[])
parser.add_argument("--checkpoint_size", type=int, default=0)
# gt is marginal on val
parser.add_argument(
"--category_metrics",
nargs="*",
default=[
"Leaderboard",
"Aggr_Loss",
"Aggr_BCELoss",
"Aggr_Tie_Loss",
"Aggr_Tie_Accuracy",
"Aggr_Tie_bb_Accuracy",
"Aggr_Tie_bb_Loss",
"L1-Dist-Prob",
"Spearman-lbs",
"Kendall-lbs",
"IQR-BT",
"Std-BT",
"Spread-BT",
"Top-k-fraction",
"Top-k-displace",
],
)
parser.add_argument(
"--rand_subset_sizes", nargs="*", default=[250, 500, 1000, 2000]
)
parser.add_argument("--rand_num_samples", nargs="*", default=[50, 20, 5, 3])
parser.add_argument(
"--rand_subset_metrics",
nargs="*",
default=["L1-Dist-Prob", "Spearman-lbs", "Kendall-lbs"],
)
# gt is arena leaderboard
parser.add_argument(
"--aggr_scale_subset_sizes",
nargs="*",
default=[1, 10, 25, 100, 250, 500, 1000, 2000],
)
parser.add_argument(
"--aggr_scale_num_samples",
nargs="*",
default=[500, 500, 500, 200, 100, 40, 10, 6],
)
parser.add_argument(
"--aggr_scale_metrics",
nargs="*",
default=["L1-Dist-Prob", "Spearman-lbs", "Kendall-lbs"],
)
parser.add_argument("--ground_truth", type=str, default="marginal-gt")
parser.add_argument(
"--metrics_to_inc",
nargs="*",
default=["simple", "category", "random_subsets", "aggr_scale"],
)
parser.add_argument("--remove_last_hidden_json", default=True)
args = parser.parse_args()
start_time = time.time()
for idx in range(len(args.output_file_name)):
results = {}
results["params"] = copy.deepcopy(vars(args))
train_model_list = parse_model_list(args.model_repo, args.model_list_path)
eval_path = args.eval_path[idx] if args.eval_path else None
checkpoint_path = args.checkpoint_path[idx] if args.checkpoint_path else None
hf_checkpoint_file = (
args.hf_checkpoint_file[idx] if args.hf_checkpoint_file else None
)
# make sure right params are dumped
results["params"]["eval_path"] = eval_path
results["params"]["checkpoint_path"] = checkpoint_path
results["params"]["hf_checkpoint_file"] = hf_checkpoint_file
val_data, val_model_list = parse_eval_output_data(
args.model_repo,
eval_path,
checkpoint_path,
args.hf_checkpoint_repo,
hf_checkpoint_file,
args.loss_type,
train_model_list,
args.remove_last_hidden_json,
)
train_data = parse_train_data(
args.hf_train_dataset, args.train_path, args.loss_type, train_model_list
)
        # align arena betas with the evaluation model list
        arena_data = parse_arena_data(args.arena_path, val_model_list)
models = {}
for category in args.categories:
cat_val_data = filter_battle_data(val_data, category)
cat_train_data = filter_battle_data(train_data, category)
arena_rankings = get_arena_rankings(arena_data, category)
current_model = str(args.model_type) + "-" + category
models[current_model] = get_metrics(
cat_val_data,
cat_train_data,
arena_rankings,
val_model_list,
train_model_list,
args,
)
# merely for marginal train checkpointing
for checkpoint in args.train_checkpoints:
num_data = checkpoint * args.checkpoint_size
checkpoint_train_data = train_data.head(num_data)
cat_train_data = filter_battle_data(checkpoint_train_data, category)
models[current_model + f"-checkpoint-{checkpoint}"] = get_metrics(
cat_val_data,
cat_train_data,
arena_rankings,
val_model_list,
train_model_list,
args,
)
results["models"] = models
save_output(
results, args.output_dir, args.hf_output_dir, args.output_file_name[idx]
)
end_time = time.time()
total_time = end_time - start_time
minutes = int(total_time // 60)
seconds = int(total_time % 60)
print(f"\nTotal time taken: {minutes} minutes and {seconds} seconds")
================================================
FILE: p2l/dataset.py
================================================
from transformers import PreTrainedTokenizer
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import torch
from typing import List
def get_model_list(dataset: Dataset):
model_a_values = dataset.unique("model_a")
model_b_values = dataset.unique("model_b")
model_list_with_repeats = []
for value in model_a_values:
model_list_with_repeats.append(value)
for value in model_b_values:
model_list_with_repeats.append(value)
model_set = set(model_list_with_repeats)
model_list = sorted(list(model_set))
return model_list
def get_dataset(path: str, split: str, from_disk=False):
if from_disk:
dataset = load_from_disk(path)
if isinstance(dataset, DatasetDict):
dataset = dataset[split]
return dataset
else:
return load_dataset(path, split=split)
def _translate_label(
labels: List[int], train_model_list: List[str], val_model_list: List[str]
) -> List[int]:
label_copy = labels[:]
label_copy[0] = train_model_list.index(val_model_list[labels[0]])
label_copy[1] = train_model_list.index(val_model_list[labels[1]])
return label_copy
def translate_val_data(
val_data: Dataset, train_model_list: List[str], val_model_list: List[str]
) -> Dataset:
# Validate val models
for val_model in val_model_list:
assert val_model in train_model_list, val_model
# Translate val dataset
val_data = val_data.map(
lambda labels: {
"labels": _translate_label(labels, train_model_list, val_model_list)
},
input_columns="labels",
num_proc=16,
)
return val_data
class DataCollator:
def __init__(self, tokenizer, max_length, weight=None, reweight_scale=None):
self.tokenizer: PreTrainedTokenizer = tokenizer
self.max_length: int = max_length
self.weight: bool = weight
self.reweight_scale: float = reweight_scale
self.first = True
def __call__(self, data):
prompts = []
for seq in data:
if isinstance(seq["prompt"], str):
prompts.append([{"role": "user", "content": seq["prompt"]}])
else:
prompts.append([{"role": "user", "content": turn} for turn in seq["prompt"]])
labels = torch.tensor([seq["labels"].tolist() for seq in data])
formatted_prompts = self.tokenizer.apply_chat_template(
prompts,
tokenize=False,
add_generation_prompt=False,
add_special_tokens=False,
)
        # Scrub any instances of the cls token from the data, otherwise the model will error.
formatted_prompts = [
prompt.replace(self.tokenizer.cls_token, "<cls>")
for prompt in formatted_prompts
]
formatted_prompts = [
seq + self.tokenizer.cls_token for seq in formatted_prompts
]
if self.first:
print(formatted_prompts)
self.first = False
encoded = self.tokenizer(
formatted_prompts,
padding=True,
return_tensors="pt",
add_special_tokens=False,
truncation=True,
max_length=self.max_length,
)
out = {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"labels": labels,
}
if self.weight:
if "weight" in data[0]:
out["weights"] = torch.tensor([seq["weight"].tolist() for seq in data])
if self.reweight_scale:
out["weights"] *= self.reweight_scale
else:
out["weights"] = None
return out
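# Illustrative usage sketch (not part of the original training pipeline; the tokenizer
# name, cls token, and toy rows below are assumptions):
def _example_collate_batch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    if "cls_token" not in tokenizer.special_tokens_map:
        tokenizer.add_special_tokens({"cls_token": "<|cls|>"})
    collator = DataCollator(tokenizer, max_length=8192)
    rows = [
        {"prompt": "What is 2 + 2?", "labels": torch.tensor([3, 7, 0])},
        {"prompt": "Write a limerick about trains.", "labels": torch.tensor([1, 5, 1])},
    ]
    batch = collator(rows)
    # batch holds padded "input_ids", "attention_mask", and the stacked "labels" tensor.
    return batch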
================================================
FILE: p2l/endpoint.py
================================================
import argparse
import json
from typing import Dict, Tuple, List, Optional
import torch
import uvicorn
from fastapi import FastAPI, Header, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import (
AutoTokenizer,
TextClassificationPipeline,
pipeline,
PreTrainedModel,
)
from p2l.model import get_p2l_model, P2LOutputs
from contextlib import asynccontextmanager
import logging
logging.getLogger().setLevel(logging.DEBUG)
def parse_args():
parser = argparse.ArgumentParser(description="Run FastAPI with Uvicorn")
parser.add_argument(
"--model-path",
"-m",
type=str,
default="p2el/Qwen2.5-7B-Instruct-rk-full-train",
help="Path to the model repository",
)
parser.add_argument(
"--model-type",
"-mt",
type=str,
default="qwen2",
help="Type of the model",
)
parser.add_argument(
"--head-type",
"-ht",
type=str,
default="rk",
help="Type of model head",
)
parser.add_argument(
"--loss-type",
"-lt",
type=str,
default="rk",
help="Type of the loss function",
)
parser.add_argument(
"--api-key",
"-a",
type=str,
default="-",
help="API key for authorization",
)
parser.add_argument(
"--host",
"-H",
type=str,
default="0.0.0.0",
help="Host to run the server on",
)
parser.add_argument(
"--port",
"-p",
type=int,
default=10250,
help="Port to run the server on",
)
parser.add_argument(
"--reload",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether to reload the endpoint on detected code change, needs workers to be 1.",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="Number of endpoint workers (will hold a model per worker).",
)
parser.add_argument(
"--cuda",
action=argparse.BooleanOptionalAction,
default=True,
help="Flag to enable using a GPU to host the model. Flag is true by default.",
)
args = parser.parse_args()
return args
@asynccontextmanager
async def lifespan(app: FastAPI):
args = parse_args()
model, tokenizer, model_list = load_model(
args.model_path,
args.model_type,
args.head_type,
args.loss_type,
)
pipe = pipeline(
task="text-classification",
model=model,
tokenizer=tokenizer,
device="cuda" if args.cuda else "cpu",
pipeline_class=P2LPipeline,
)
app.state.api_key = args.api_key
app.state.model_list = model_list
app.state.model = model
app.state.tokenizer = tokenizer
app.state.pipe = pipe
try:
yield
finally:
pass
# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)
# Define the input data structure
class InputData(BaseModel):
prompt: list[str]
class OutputData(BaseModel):
coefs: List[float]
eta: Optional[float] = None
class ModelList(BaseModel):
models: List[str]
class P2LPipeline(TextClassificationPipeline):
def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:
return_tensors = self.framework
inputs = inputs["prompt"]
messages = [{"role": "user", "content": p} for p in inputs]
formatted = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
add_special_tokens=False,
)
formatted = formatted + self.tokenizer.cls_token
logging.debug(f"Formatted input: {formatted}")
return self.tokenizer(
formatted,
return_tensors=return_tensors,
max_length=8192,
padding="longest",
truncation=True,
)
def postprocess(
self, model_outputs: P2LOutputs, function_to_apply=None, top_k=1, _legacy=True
):
model_outputs = P2LOutputs(model_outputs)
eta = model_outputs.eta
return OutputData(
coefs=model_outputs.coefs.cpu().float().tolist()[0],
            eta=eta.cpu().float().item() if eta is not None else None,
)
@app.post("/predict")
async def predict(input_data: InputData, api_key: str = Header(...)):
logging.debug(f"Received Request: {input_data}.")
if api_key != app.state.api_key:
raise HTTPException(status_code=403, detail="Unauthorized")
try:
pipe: P2LPipeline = app.state.pipe
logging.debug(f"Input Prompt: {input_data.prompt}")
output = pipe(inputs=input_data.model_dump())
logging.debug(f"Output: {output}")
return output
except Exception as e:
logging.debug(e)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def models(api_key: str = Header(...)):
logging.debug(f"Received Model List Request.")
if api_key != app.state.api_key:
raise HTTPException(status_code=403, detail="Unauthorized")
try:
return ModelList(
models=app.state.model_list,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def load_model(
model_name, model_type, head_type, loss_type
) -> Tuple[PreTrainedModel, AutoTokenizer, List[str]]:
# Download and load the model list
fname = hf_hub_download(
repo_id=model_name, filename="model_list.json", repo_type="model"
)
with open(fname) as fin:
model_list = json.load(fin)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.truncation_side = "left"
tokenizer.padding_side = "right"
# Get the model class and load the model
model_cls = get_p2l_model(model_type, loss_type, head_type)
model = model_cls.from_pretrained(
model_name,
CLS_id=tokenizer.cls_token_id,
num_models=len(model_list),
torch_dtype=torch.bfloat16,
)
return model, tokenizer, model_list
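# Illustrative client sketch (not part of the original server code): how a caller might
# query the /predict route defined above. Host, port, and api key mirror the argparse
# defaults; the "requests" dependency is an assumption.
def _example_predict_request():
    import requests

    resp = requests.post(
        "http://localhost:10250/predict",
        json={"prompt": ["Write a haiku about the ocean."]},
        headers={"api-key": "-"},
        timeout=60,
    )
    # The response mirrors OutputData: {"coefs": [...], "eta": ...}.
    return resp.json()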
if __name__ == "__main__":
args = parse_args()
uvicorn.run(
"p2l.endpoint:app",
port=args.port,
host=args.host,
reload=args.reload,
workers=args.workers,
)
================================================
FILE: p2l/eval.py
================================================
import argparse
from p2l.model import get_p2l_model, P2LOutputs
from transformers import pipeline, TextClassificationPipeline, AutoTokenizer
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import torch
from typing import Dict
import pandas as pd
import os
import json
from tqdm.auto import tqdm
from torch.utils.data import Dataset
from glob import glob
class P2LPipeline(TextClassificationPipeline):
def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Tensor]:
return_tensors = self.framework
messages = [{"role": "user", "content": inputs}]
formatted = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
add_special_tokens=False,
)
formatted = formatted + self.tokenizer.cls_token
return self.tokenizer(
formatted,
return_tensors=return_tensors,
max_length=8192,
padding="longest",
truncation=True,
)
def postprocess(
self, model_outputs: P2LOutputs, function_to_apply=None, top_k=1, _legacy=True
):
model_outputs = P2LOutputs(model_outputs)
eta = model_outputs.eta
gamma = model_outputs.gamma
return dict(
coefs=model_outputs.coefs.cpu().float().numpy(),
            eta=eta.cpu().float().numpy() if eta is not None else None,
            gamma=gamma.cpu().float().numpy() if gamma is not None else None,
last_hidden_state=model_outputs.last_hidden_state.cpu().float().numpy(),
)
class ListDataset(Dataset):
def __init__(self, original_list):
self.original_list = original_list
def __len__(self):
return len(self.original_list)
def __getitem__(self, i):
return self.original_list[i]
def main(args, local_file=None):
os.makedirs(args.output_dir, exist_ok=True)
dataset = load_dataset(args.dataset, split=args.dataset_split)
if local_file:
fname = os.path.join(local_file, "model_list.json")
else:
fname = hf_hub_download(
repo_id=args.model_path, filename="model_list.json", repo_type="model"
)
with open(fname) as fin:
model_list = json.load(fin)
model_cls = get_p2l_model(args.model_type, args.loss_type, args.head_type)
if local_file:
tokenizer = AutoTokenizer.from_pretrained(local_file, local_files_only=True)
model = model_cls.from_pretrained(
local_file,
CLS_id=tokenizer.cls_token_id,
num_models=len(model_list),
torch_dtype=torch.bfloat16,
local_files_only=True,
)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
model = model_cls.from_pretrained(
args.model_path,
CLS_id=tokenizer.cls_token_id,
num_models=len(model_list),
torch_dtype=torch.bfloat16,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
task="text-classification",
model=model,
tokenizer=tokenizer,
device=device,
pipeline_class=P2LPipeline,
)
prompts = ListDataset(dataset["prompt"])
with torch.no_grad():
outputs = [
out
for out in tqdm(
pipe(prompts, batch_size=args.batch_size), total=len(prompts)
)
]
df = dataset.to_pandas()
outputs_df = pd.DataFrame.from_records(outputs)
if args.drop_hidden:
outputs_df = outputs_df.drop("last_hidden_state", axis=1)
df = pd.concat((df, outputs_df), axis=1)
if local_file:
fname = local_file.split("/")[-1] + ".json"
else:
fname = args.model_path.split("/")[-1] + ".json"
fpath = os.path.join(args.output_dir, fname)
df.to_json(fpath, orient="records", indent=4, force_ascii=False)
if args.output_hf_path:
from datasets import Dataset
df = pd.read_json(fpath)
hf_dataset = Dataset.from_pandas(df)
hf_dataset.push_to_hub(args.output_hf_path, private=True)
print("Results pushed to hub!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-path", "-m", type=str, default=None, help="Huggingface model path"
)
parser.add_argument(
"--training-output-dir", "-t", type=str, default=None
)
parser.add_argument(
"--dataset", "-d", type=str, required=True, help="Huggingface dataset path"
)
parser.add_argument("--output-hf-path", "-oh", type=str, default=None)
parser.add_argument(
"--dataset-split",
"-ds",
type=str,
default="train",
help="Huggingface dataset split",
)
parser.add_argument(
"--model-type",
"-mt",
type=str,
default="qwen2",
help="Model type (qwen2, llama, etc)",
)
parser.add_argument(
"--head-type",
"-ht",
type=str,
default="bt",
help="Head type (Bradely Terry, Rao-Kupper, etc)",
)
parser.add_argument(
"--loss-type",
"-lt",
type=str,
default="bt",
help="Loss type (Bradely Terry, Rao-Kupper, etc)",
)
parser.add_argument("--batch-size", "-bs", type=int, default=1, help="Batch size")
parser.add_argument("--output-dir", "-od", type=str, default="outputs")
parser.add_argument("--drop-hidden", action=argparse.BooleanOptionalAction, default=False)
args = parser.parse_args()
if args.training_output_dir:
for file in glob(os.path.join(args.training_output_dir, "*")):
main(args, file)
else:
main(args)
================================================
FILE: p2l/model.py
================================================
import torch
from transformers import (
Qwen2Model,
Qwen2PreTrainedModel,
LlamaModel,
LlamaPreTrainedModel,
PreTrainedModel,
AutoTokenizer,
)
from transformers.utils import ModelOutput
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Callable, Optional
registered_transformers: Dict[str, Tuple[PreTrainedModel, PreTrainedModel]] = {
"qwen2": (Qwen2PreTrainedModel, Qwen2Model),
"llama": (LlamaPreTrainedModel, LlamaModel),
}
registered_losses: Dict[str, Callable] = {}
registered_heads: Dict[str, nn.Module] = {}
registered_inits: Dict[str, Callable] = {}
registered_aggr_models: Dict[str, nn.Module] = {}
registered_pairwise_losses: Dict[str, Callable] = {}
def register_loss(name: str):
def decorator(func: Callable):
registered_losses[name] = func
return func
return decorator
def register_head(name: str):
def decorator(func: Callable):
registered_heads[name] = func
return func
return decorator
def register_init(name: str):
def decorator(func: Callable):
registered_inits[name] = func
return func
return decorator
def register_aggr_model(name: str):
def decorator(func: Callable):
registered_aggr_models[name] = func
return func
return decorator
def register_pairwise_loss(name: str):
def decorator(func: Callable):
registered_pairwise_losses[name] = func
return func
return decorator
@dataclass
class HeadOutputs(ModelOutput):
coefs: torch.FloatTensor = None
eta: Optional[torch.FloatTensor] = None
gamma: Optional[torch.FloatTensor] = None
@dataclass
class P2LOutputs(ModelOutput):
coefs: torch.FloatTensor = None
eta: Optional[torch.FloatTensor] = None
gamma: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
@register_loss("bt")
def BT_loss(
head_output: HeadOutputs,
labels: torch.Tensor,
weights: torch.Tensor = None,
**kwargs,
):
# labels columns are in the form (winner_idx, loser_idx)
coefs = head_output.coefs
paired_coefs = coefs.gather(dim=-1, index=labels).contiguous()
paired_delta_logit = (
paired_coefs[:, 0] - paired_coefs[:, 1]
    )  # winner coefficient minus loser coefficient
neg_log_sigma = -F.logsigmoid(paired_delta_logit) # get neg log prob
if weights is not None:
neg_log_sigma = neg_log_sigma * weights
loss = neg_log_sigma.mean()
return loss
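# For reference, with winner coefficient beta_w and loser coefficient beta_l, the
# Bradley-Terry negative log-likelihood computed above is
#     L = -log(sigmoid(beta_w - beta_l)) = log(1 + exp(beta_l - beta_w)).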
@register_loss("bt-tie")
def BT_tie_loss(
head_output: HeadOutputs,
labels: torch.Tensor,
weights: torch.Tensor = None,
**kwargs,
):
# labels columns are in the form (winner_idx, loser_idx, tie_indicator)
coefs = head_output.coefs
model_idx = labels[:, :2] # (batch_dim, 2)
tie_ind = labels[:, -1]
paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()
paired_delta_logit = (
paired_coefs[:, 0] - paired_coefs[:, 1]
    )  # winner coefficient minus loser coefficient
# computes bradley-terry loss where tie is half win and half loss
neg_log_sigma = -1 * torch.where(
tie_ind == 0,
F.logsigmoid(paired_delta_logit),
0.5
* (F.logsigmoid(paired_delta_logit) + F.logsigmoid(-1 * paired_delta_logit)),
)
if weights is not None:
neg_log_sigma = neg_log_sigma * weights
loss = neg_log_sigma.mean()
return loss
BETA = 0.1
@register_loss("rk")
def RK_Loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
# labels columns are in form (winner_idx, loser_idx, tie_indicator)
coefs = head_output.coefs
# eta = torch.exp(head_output.eta).squeeze(-1) # eta > 0
eta = torch.clamp(
torch.nn.functional.softplus(head_output.eta - 22.5, BETA).squeeze(-1), min=0.02
)
# eta = torch.abs(head_output.eta).squeeze(-1)
model_idx = labels[:, :2] # (batch_dim, 2)
paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
# compute RK probabilities
p_w = torch.sigmoid(paired_delta_logit - eta)
p_l = torch.sigmoid(-1 * paired_delta_logit - eta)
p_t = 1 - p_w - p_l
# point-wise likelihood
A = torch.stack((p_w, p_t)) # (2, batch_dim)
tie_ind = labels[:, -1].unsqueeze(0) # (1, batch_dim)
p = A.take_along_dim(dim=0, indices=tie_ind)
    # mathematically 0 < p < 1, but bfloat16 rounding can push it to (or below) zero, so clamp before the log
p = torch.clamp(p, min=1e-3)
# eps = 1e-10
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
return loss
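# For reference, with delta = beta_winner - beta_loser and tie margin eta (clamped to
# be >= 0.02 above), the Rao-Kupper probabilities are
#     P(win)  = sigmoid(delta - eta)
#     P(loss) = sigmoid(-delta - eta)
#     P(tie)  = 1 - P(win) - P(loss)
# and the loss is -log P(observed outcome).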
@register_loss("rk-reparam")
def RK_Reparam_Loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = head_output.eta
theta = torch.exp(eta) + 1.000001
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
p_win = pi_win / (pi_win + theta * pi_lose + 1.0)
p_lose = pi_lose / (pi_lose + theta * pi_win + 1.0)
p_tie = 1.0 - p_win - p_lose
assert p_win.shape == p_lose.shape == p_tie.shape
P = torch.hstack((p_win, p_tie))
tie_ind = labels[:, -1].unsqueeze(-1)
p = P.gather(dim=-1, index=tie_ind).contiguous()
p = torch.clamp(p, min=1e-6)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
return loss
@register_loss("ba")
def BA_loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
# labels are (winner_idx, loser_idx, tie_indicator (0 for no tie, 1 for tie, 2 for tie both bad))
coefs = head_output.coefs
eta = head_output.eta
gamma = head_output.gamma
theta = torch.exp(eta) + 1.02
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
pi_gamma = torch.exp(gamma)
p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)
p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)
p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)
p_tie = 1.0 - p_win - p_lose - p_tie_bb
P = torch.hstack((p_win, p_tie, p_tie_bb))
tie_ind = labels[:, -1].unsqueeze(-1)
p = P.gather(dim=-1, index=tie_ind).contiguous()
p = torch.clamp(p, min=1e-2)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
print("loss: ", loss.item())
return loss
@register_loss("bag")
@register_loss("grk")
def GRK_loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
# labels are (winner_idx, loser_idx, tie_indicator (0 for no tie, 1 for tie, 2 for tie both bad))
coefs = head_output.coefs.float()
eta = head_output.eta.float()
theta = torch.exp(eta) + 1.000001
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
pi_gamma = 1.0
p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)
p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)
p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)
p_tie = 1.0 - p_win - p_lose - p_tie_bb
assert p_win.shape == p_lose.shape == p_tie_bb.shape == p_tie.shape
P = torch.hstack((p_win, p_tie, p_tie_bb))
tie_ind = labels[:, -1].unsqueeze(-1)
p = P.gather(dim=-1, index=tie_ind).contiguous()
p = torch.clamp(p, min=1e-6)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
# print("loss: ", loss.item())
return loss
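# For reference, the grounded Rao-Kupper ("bag") model above uses pi = exp(beta),
# theta = exp(eta) + 1 (plus a tiny epsilon in the code), and a fixed "both bad"
# anchor pi_gamma = 1:
#     P(win)          = pi_w / (pi_w + theta * pi_l + 1)
#     P(loss)         = pi_l / (pi_l + theta * pi_w + 1)
#     P(tie both bad) = 1 / (1 + pi_w + pi_l)
#     P(tie)          = 1 - P(win) - P(loss) - P(tie both bad)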
@register_head("bt")
class BTHead(nn.Module):
def __init__(
self, input_dim, output_dim, linear_head_downsize_factor=None, **kwargs
) -> None:
super().__init__()
if linear_head_downsize_factor:
inner_dim = int(output_dim // linear_head_downsize_factor)
self.head = nn.Sequential(
nn.Linear(in_features=input_dim, out_features=inner_dim, bias=True),
nn.Linear(in_features=inner_dim, out_features=output_dim, bias=True),
)
else:
self.head = nn.Linear(
in_features=input_dim, out_features=output_dim, bias=True
)
def forward(self, last_hidden_dim: torch.Tensor):
coefs = self.head(last_hidden_dim)
return HeadOutputs(coefs=coefs)
@register_head("rk")
class RKHead(nn.Module):
def __init__(
self,
input_dim,
output_dim,
eta_dim=1,
linear_head_downsize_factor=None,
eta_downsize=False,
**kwargs,
) -> None:
super().__init__()
        # If linear_head_downsize_factor is set and eta_downsize is True, eta is computed from the downsized dim rather than the hidden dim.
if linear_head_downsize_factor:
inner_dim = output_dim // linear_head_downsize_factor
share_layer = nn.Linear(
in_features=input_dim, out_features=inner_dim, bias=True
)
self.head = nn.Sequential(
share_layer,
nn.Linear(in_features=inner_dim, out_features=output_dim, bias=True),
)
if eta_downsize:
self.eta_head = nn.Sequential(
share_layer,
nn.Linear(in_features=inner_dim, out_features=eta_dim, bias=True),
)
else:
self.eta_head = nn.Linear(
in_features=output_dim, out_features=eta_dim, bias=True
)
else:
self.head = nn.Linear(
in_features=input_dim, out_features=output_dim, bias=True
)
self.eta_head = nn.Linear(
in_features=input_dim, out_features=eta_dim, bias=True
)
def forward(self, last_hidden_dim: torch.Tensor):
coefs = self.head(last_hidden_dim)
eta = self.eta_head(last_hidden_dim)
return HeadOutputs(coefs=coefs, eta=eta)
@register_head("ba")
class BAHead(nn.Module):
def __init__(
self,
input_dim,
output_dim,
linear_head_downsize_factor=None,
**kwargs,
) -> None:
super().__init__()
if linear_head_downsize_factor:
raise NotImplementedError("Sorry I didn't implement this.")
self.head = nn.Linear(in_features=input_dim, out_features=output_dim, bias=True)
self.eta_head = nn.Linear(in_features=input_dim, out_features=1, bias=True)
self.gamma_head = nn.Linear(in_features=input_dim, out_features=1, bias=True)
def forward(self, last_hidden_dim: torch.Tensor):
coefs = self.head(last_hidden_dim)
eta = self.eta_head(last_hidden_dim)
gamma = self.gamma_head(last_hidden_dim)
return HeadOutputs(coefs=coefs, eta=eta, gamma=gamma)
@register_init("reset_params")
def reset_params_init(module):
return module.reset_parameters()
@register_init("he_unif")
def he_unif_init(module):
return nn.init.kaiming_uniform_(module.weight, nonlinearity="sigmoid")
@register_init("xavier_unif")
def xavier_unif_init(module):
return nn.init.xavier_uniform_(module.weight)
@register_init("tiny_normal")
def tiny_normal_init(module):
return nn.init.kaiming_normal_(module.weight)
def get_p2l_model(
model_type: str, loss_type: str, head_type: str, init_type: str = "reset_params"
) -> PreTrainedModel:
pretrained_model_cls, model_cls = registered_transformers[model_type]
criterion = registered_losses[loss_type]
head_layer = registered_heads[head_type]
init_func = registered_inits[init_type]
class CustomPretrainedModel(pretrained_model_cls):
"""Defines the appropriate pretrained class for the given model name. This is done so that the value head init scheme is correct."""
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
init_func(module) # was reset params
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class P2LModel(CustomPretrainedModel):
def __init__(
self,
config,
CLS_id,
num_models,
linear_head_downsize_factor=None,
head_kwargs={},
**kwargs,
):
super().__init__(config)
self.num_models = num_models
self.cls_token_id = CLS_id
self.model = model_cls(config)
self.head = head_layer(
input_dim=config.hidden_size,
output_dim=self.num_models,
linear_head_downsize_factor=linear_head_downsize_factor,
**head_kwargs,
)
self.post_init()
def freeze_transformer(self):
for param in self.model.parameters():
param.requires_grad = False
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def forward(self, input_ids, attention_mask, labels=None, weights=None):
batch_size = input_ids.shape[0]
hidden_outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=False,
).last_hidden_state # (bs, num_token, embed_dim)
cls_mask = input_ids == self.cls_token_id
            # double check this is selecting the correct CLS token position
cls_hidden_dim = hidden_outputs[cls_mask]
assert (
cls_hidden_dim.shape[0] == batch_size
), f"input ids {input_ids.shape}, cls_mask {cls_mask.shape}, cls_logit {cls_hidden_dim.shape}"
head_output = self.head(cls_hidden_dim)
if labels is not None:
loss = criterion(head_output, labels, weights=weights)
outputs = P2LOutputs(
coefs=head_output.coefs,
last_hidden_state=cls_hidden_dim,
eta=head_output.eta,
gamma=head_output.gamma,
loss=loss,
)
else:
outputs = P2LOutputs(
coefs=head_output.coefs,
last_hidden_state=cls_hidden_dim,
eta=head_output.eta,
gamma=head_output.gamma,
)
return outputs
return P2LModel
def get_tokenizer(
tokenizer_name,
chat_template,
pad_token_if_none="<|pad|>",
cls_token_if_none="<|cls|>",
):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.truncation_side = "left"
tokenizer.padding_side = "right"
if chat_template:
tokenizer.chat_template = chat_template
if "pad_token" not in tokenizer.special_tokens_map:
tokenizer.add_special_tokens({"pad_token": pad_token_if_none})
if "cls_token" not in tokenizer.special_tokens_map:
tokenizer.add_special_tokens({"cls_token": cls_token_if_none})
return tokenizer
@register_aggr_model("bt")
@register_aggr_model("bt-tie")
class BTAggrModel(nn.Module):
def __init__(self, num_models, batch_size=1):
super().__init__()
self.coefs = nn.Parameter(
nn.init.constant_(torch.empty(batch_size, num_models), 0.5)
)
self.eta = None
def forward(self):
return self.coefs, self.eta
@register_aggr_model("rk")
@register_aggr_model("rk-reparam")
@register_aggr_model("bag")
@register_aggr_model("grk")
class RKAggrModel(nn.Module):
def __init__(self, num_models, batch_size=1):
super().__init__()
self.coefs = nn.Parameter(
nn.init.constant_(torch.empty(batch_size, num_models), 0.5)
)
self.eta = nn.Parameter(nn.init.constant_(torch.empty(batch_size, 1), 0.1))
def forward(self):
return self.coefs, self.eta
@register_pairwise_loss("bt")
@register_pairwise_loss("bt-tie")
def pairwise_batch_BT_loss(
real_output: HeadOutputs, aggregated_output: HeadOutputs, true_probs: torch.tensor
):
real_betas = real_output.coefs
aggregated_betas = aggregated_output.coefs
num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]
beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]
pred_probs = torch.sigmoid(beta_i_agg - beta_j_agg)
pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1)
eps = 1e-9
neg_log_prob = -(
true_probs * torch.log(pred_probs_expanded + eps)
+ (1 - true_probs) * torch.log(1 - pred_probs_expanded + eps)
)
batch_losses = neg_log_prob.mean(dim=(1, 2))
loss = batch_losses.mean()
return loss
# batched loss
@register_pairwise_loss("rk")
def pairwise_batch_RK_loss(
real_output: HeadOutputs, aggregated_output: HeadOutputs, true_probs: torch.tensor
):
real_betas = real_output.coefs
num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]
aggregated_betas = aggregated_output.coefs
BETA = 0.1
aggregated_eta = torch.clamp(
torch.nn.functional.softplus(aggregated_output.eta - 22.5, BETA).squeeze(-1),
min=0.02,
)
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]
beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]
aggregated_eta = aggregated_eta.unsqueeze(-1)
pred_probs_win = torch.sigmoid(beta_i_agg - beta_j_agg - aggregated_eta)
pred_probs_loss = torch.sigmoid(beta_j_agg - beta_i_agg - aggregated_eta)
pred_probs_tie = 1 - pred_probs_win - pred_probs_loss
pred_probs = torch.stack((pred_probs_win, pred_probs_loss, pred_probs_tie), dim=-1)
pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)
eps = 1e-9
neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)
batch_losses = neg_log_prob.mean(dim=(1, 2))
loss = batch_losses.mean()
return loss
# batched
@register_pairwise_loss("rk-reparam")
def pairwise_batch_RK_reparam_loss(
real_output: HeadOutputs,
aggregated_output: HeadOutputs,
true_probs: torch.tensor,
**kwargs,
):
real_betas = real_output.coefs
num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]
aggregated_betas = aggregated_output.coefs
aggregrated_theta = torch.exp(aggregated_output.eta) + 1.000001
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]
beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]
pi_win = torch.exp(beta_i_agg)
pi_lose = torch.exp(beta_j_agg)
p_win = pi_win / (pi_win + aggregrated_theta * pi_lose + 1.0)
p_lose = pi_lose / (pi_lose + aggregrated_theta * pi_win + 1.0)
p_tie = 1.0 - p_win - p_lose
pred_probs = torch.stack((p_win, p_lose, p_tie), dim=-1)
pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)
eps = 1e-9
neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)
batch_losses = neg_log_prob.mean(dim=(1, 2))
loss = batch_losses.mean()
return loss
def get_bag_probs(beta_win, beta_lose, gamma, theta):
pi_win = torch.exp(beta_win)
pi_lose = torch.exp(beta_lose)
pi_gamma = 1.0
p_win = pi_win / (pi_win + theta * pi_lose + pi_gamma)
p_lose = pi_lose / (pi_lose + theta * pi_win + pi_gamma)
p_tie_bb = pi_gamma / (pi_gamma + pi_win + pi_lose)
p_tie = 1.0 - p_win - p_lose - p_tie_bb
return torch.stack((p_win, p_lose, p_tie, p_tie_bb), dim=-1)
# batched
@register_pairwise_loss("bag")
@register_pairwise_loss("grk")
def pairwise_batch_bag_loss(
real_output: HeadOutputs,
aggregated_output: HeadOutputs,
true_probs: torch.tensor,
**kwargs,
):
real_betas = real_output.coefs
num_prompts, num_models = real_betas.shape[-2], real_betas.shape[-1]
aggregated_betas = aggregated_output.coefs
aggregrated_theta = torch.exp(aggregated_output.eta) + 1.000001
pair_indices = torch.tensor(
[(i, j) for i in range(num_models) for j in range(i + 1, num_models)],
dtype=torch.long,
)
beta_i_agg = aggregated_betas[:, pair_indices[:, 0]]
beta_j_agg = aggregated_betas[:, pair_indices[:, 1]]
pred_probs = get_bag_probs(beta_i_agg, beta_j_agg, 1.0, aggregrated_theta)
pred_probs_expanded = pred_probs.unsqueeze(1).expand(-1, num_prompts, -1, -1)
eps = 1e-9
neg_log_prob = -torch.sum(true_probs * torch.log(pred_probs_expanded + eps), dim=-1)
batch_losses = neg_log_prob.mean(dim=(1, 2))
loss = batch_losses.mean()
return loss
@register_loss("tie-rk")
def RK_Tie_Loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = torch.clamp(
torch.nn.functional.softplus(head_output.eta - 22.5, BETA).squeeze(-1), min=0.02
)
model_idx = labels[:, :2]
paired_coefs = coefs.gather(dim=-1, index=model_idx).contiguous()
paired_delta_logit = paired_coefs[:, 0] - paired_coefs[:, 1]
p_w = torch.sigmoid(paired_delta_logit - eta)
p_l = torch.sigmoid(-1 * paired_delta_logit - eta)
p_t = 1 - p_w - p_l
p_not_t = p_w + p_l
A = torch.stack((p_not_t, p_t))
tie_ind = labels[:, -1].unsqueeze(0)
p = A.take_along_dim(dim=0, indices=tie_ind)
p = torch.clamp(p, min=1e-3)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
return loss
@register_loss("tie-bag")
@register_loss("tie-grk")
def bag_tie_loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = head_output.eta
theta = torch.exp(eta) + 1.000001
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()
p_win, p_lose, p_tie, p_tie_bb = torch.unbind(
get_bag_probs(beta_win, beta_lose, 1.0, theta), dim=-1
)
P = torch.hstack((p_win + p_lose, p_tie + p_tie_bb))
tie_ind = labels[:, -1].unsqueeze(-1)
tie_ind = torch.where(tie_ind == 0, 0, 1) # segment into ties and not ties
p = P.gather(dim=-1, index=tie_ind).contiguous()
p = torch.clamp(p, min=1e-6)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
return loss
@register_loss("tie-bb-bag")
@register_loss("tie-bb-grk")
def bag_tie_bb_loss(
head_output: HeadOutputs, labels: Dict, weights: torch.Tensor = None, **kwargs
):
coefs = head_output.coefs
eta = head_output.eta
theta = torch.exp(eta) + 1.000001
winner_idx = labels[:, 0:1]
loser_idx = labels[:, 1:2]
beta_win = coefs.gather(dim=-1, index=winner_idx).contiguous()
beta_lose = coefs.gather(dim=-1, index=loser_idx).contiguous()
p_win, p_lose, p_tie, p_tie_bb = torch.unbind(
get_bag_probs(beta_win, beta_lose, 1.0, theta), dim=-1
)
P = torch.hstack((p_win + p_lose + p_tie, p_tie_bb))
tie_ind = labels[:, -1].unsqueeze(-1)
tie_ind = torch.where(tie_ind == 2, 1, 0) # index should be 1 if tie-bb
p = P.gather(dim=-1, index=tie_ind).contiguous()
p = torch.clamp(p, min=1e-6)
loss = -torch.log(p)
    if weights is not None:
loss = loss * weights
loss = loss.mean()
return loss
================================================
FILE: p2l/train.py
================================================
import argparse
import os
import yaml
import json
import random
from transformers import Trainer, TrainingArguments, set_seed
from p2l.dataset import DataCollator, get_model_list, get_dataset, translate_val_data
from p2l.model import get_p2l_model, get_tokenizer
from torch.utils.data import Sampler
from typing import Optional
from huggingface_hub import HfApi
# Want control over data ordering, use no shuffle trainer.
class NoShuffleTrainer(Trainer):
def _get_train_sampler(self) -> Optional[Sampler]:
return None
def train_model(args):
with open(args.config, "r") as file:
config = yaml.safe_load(file)
learning_rate = config["learning_rate"]
# Microbatch size
batch_size = config["batch_size"]
# HF data path
train_data_path = config["train_data_path"]
val_data_path = config["val_data_path"]
output_dir = config["output_dir"]
pretrain_model_name = config["pretrain_model_name"]
    # Prompts will be truncated to this length
max_length = config["max_length"]
gradient_accumulation_steps = config["gradient_accumulation_steps"]
# Deepspeed config choices can be found in the deepspeed directory
deepspeed_config_path = config["deepspeed_config_path"]
# Type of transformer, see model.py for options.
model_type = config["model_type"]
    # Loss type (e.g., bt, rk), see model.py for options.
loss_type = config["loss_type"]
# The linear head type, see model.py for options.
head_type = config["head_type"]
# Epsilon value for Adam
adam_epsilon = config["adam_epsilon"]
# Optional
epochs = config.get("num_train_epochs", 1)
lr_scheduler = config.get("lr_schedule", "constant")
chat_template = config.get("chat_template", None)
# Downsize the rank of the classification head.
linear_head_downsize_factor = config.get("linear_head_downsize_factor", None)
# Whether to weight the loss. If this is true, it expects that the dataset has a "weight" column.
weighted_loss = config.get("weighted_loss", False)
# kwargs for the head init.
head_config = config.get("head_config", {})
# If the tokenizer/model does not already have a cls token, this will be used.
cls_token_if_none = config.get("cls_token_if_none", "<|cls|>")
# If the tokenizer/model does not already have a pad token, this will be used.
pad_token_if_none = config.get("pad_token_if_none", "<|pad|>")
# If using weighted loss, scalar reweight factor
reweight_scale = config.get("reweight_scale", None)
proj_name = config.get("proj_name", None)
init_type = config.get("init_type", "reset_params")
train_head_only = config.get("train_head_only", False)
load_train_data_from_disk = config.get("load_train_data_from_disk", False)
load_val_data_from_disk = config.get("load_val_data_from_disk", False)
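    # Illustrative config sketch covering the keys read above (all values are
    # hypothetical; see the YAML files under training_configs/ for real settings):
    #
    #   learning_rate: 8.0e-6
    #   batch_size: 4
    #   train_data_path: <hf-dataset-or-local-path>
    #   val_data_path: <hf-dataset-or-local-path>
    #   output_dir: outputs
    #   pretrain_model_name: Qwen/Qwen2.5-7B-Instruct
    #   max_length: 8192
    #   gradient_accumulation_steps: 8
    #   deepspeed_config_path: deepspeed/zero1.json
    #   model_type: qwen2
    #   loss_type: rk
    #   head_type: rk
    #   adam_epsilon: 1.0e-8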
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", -1))
os.makedirs(output_dir, exist_ok=True)
# define project name
if not proj_name:
proj_name = f"{pretrain_model_name.split('/')[1]}_lr{learning_rate}_bs{batch_size}_ep{epochs}"
print(f"project name: {proj_name}")
output_path = os.path.join(output_dir, proj_name)
if args.checkpoint:
resume_from_checkpoint = args.checkpoint
print("resuming from checkpoint")
else:
resume_from_checkpoint = False
if not resume_from_checkpoint:
version = 1
while os.path.exists(output_path):
output_path = output_path.replace(f"_{version - 1}", "")
output_path = output_path + f"_{version}"
version += 1
with open(deepspeed_config_path) as fin:
deepspeed_config = json.load(fin)
random.seed(42)
set_seed(42)
training_args = TrainingArguments(
output_dir=output_path,
report_to="wandb",
run_name=proj_name,
num_train_epochs=epochs,
gradient_accumulation_steps=gradient_accumulation_steps,
save_strategy="no" if args.save_steps == -1 else "steps",
save_steps=None if args.save_steps == -1 else args.save_steps,
save_only_model=True,
eval_strategy="no",
logging_strategy="steps",
logging_steps=1,
ddp_timeout=9999999,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
eval_accumulation_steps=1,
eval_steps=args.eval_steps,
lr_scheduler_type=lr_scheduler,
logging_dir="./logs",
fp16=False,
bf16=True,
learning_rate=learning_rate,
adam_epsilon=adam_epsilon,
load_best_model_at_end=False,
gradient_checkpointing=True,
do_train=True,
bf16_full_eval=True,
save_safetensors=True,
disable_tqdm=False,
remove_unused_columns=False,
deepspeed=deepspeed_config,
seed=42,
data_seed=42,
local_rank=LOCAL_RANK,
)
tokenizer = get_tokenizer(
pretrain_model_name,
chat_template,
pad_token_if_none=pad_token_if_none,
cls_token_if_none=cls_token_if_none,
)
data_collator = DataCollator(
tokenizer, max_length, weight=weighted_loss, reweight_scale=reweight_scale
)
train_data = get_dataset(
train_data_path, "train", from_disk=load_train_data_from_disk
)
if not args.no_eval:
val_data = get_dataset(val_data_path, "train", from_disk=load_val_data_from_disk)
# with training_args.main_process_first():
model_list = get_model_list(train_data)
if not args.no_eval:
val_model_list = get_model_list(val_data)
if model_list != val_model_list:
print("WARNING: Val model list is different, translating...")
val_data = translate_val_data(val_data, model_list, val_model_list)
if LOCAL_RANK <= 0:
# Document the configuration in the output path.
os.makedirs(output_path, exist_ok=False)
with open(os.path.join(output_path, "training_config.json"), "w") as fout:
json.dump(config, fout, indent=1)
# Save the model list so we know which models this model was trained on. The model list is ALWAYS sorted alphabetically.
with open(os.path.join(output_path, "model_list.json"), "w") as fout:
json.dump(model_list, fout, indent=1)
# Get the model class
model_cls = get_p2l_model(
model_type=model_type,
loss_type=loss_type,
head_type=head_type,
init_type=init_type,
)
if resume_from_checkpoint:
print(f"Loading model from checkpoint: {resume_from_checkpoint}")
model = model_cls.from_pretrained(
resume_from_checkpoint,
CLS_id=tokenizer.cls_token_id,
num_models=len(model_list),
linear_head_downsize_factor=linear_head_downsize_factor,
)
else:
model = model_cls.from_pretrained(
pretrain_model_name,
CLS_id=tokenizer.cls_token_id,
num_models=len(model_list),
linear_head_downsize_factor=linear_head_downsize_factor,
)
if model.config.vocab_size < len(tokenizer):
print("WARNING: Resizing Token Embedding")
model.resize_token_embeddings(len(tokenizer))
if train_head_only:
print("Freezing transformer, only training head.")
model.freeze_transformer()
trainer = NoShuffleTrainer(
model=model,
args=training_args,
train_dataset=train_data.with_format("torch"),
# eval_dataset=val_data.with_format("torch"),
data_collator=data_collator,
)
print("begin training")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
trainer.save_model(output_path)
tokenizer.save_pretrained(output_path)
print("saved model and tokenizer")
if not args.no_eval:
print("starting eval")
eval_results = trainer.predict(val_data.with_format("torch"))
eval_metrics = eval_results.metrics
eval_predictions = eval_results.predictions
print(f"Evaluation Results: {eval_metrics}")
val_set = val_data.add_column("betas", list(eval_predictions[0]))
if LOCAL_RANK <= 0:
with open(os.path.join(output_path, "eval_results.json"), "w") as fout:
json.dump(eval_metrics, fout, indent=1)
val_dir = os.path.join(output_path, "eval_output.jsonl")
val_set.to_json(val_dir)
print(f"saved merged eval results")
if LOCAL_RANK <= 0:
if args.push_to_hf:
api = HfApi()
repo_id = config.get("repo_id", f"p2el/{proj_name}")
assert not api.repo_exists(
repo_id=repo_id, repo_type="model"
), "repo already exists"
api.create_repo(repo_id=repo_id, private=True, repo_type="model")
api.upload_folder(
folder_path=output_path,
repo_id=repo_id,
repo_type="model",
)
print("pushed to hub")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Argument Parser")
parser.add_argument(
"--config", type=str, help="path to config file for model training"
)
parser.add_argument(
"--checkpoint",
type=str,
help="path to checkpoint directory to resume training from",
default=None,
)
parser.add_argument(
"--push-to-hf",
action="store_true",
help="True if push directly to huggingface",
)
parser.add_argument(
"--eval-steps", type=int, default=60, help="Number of steps between evaluation."
)
parser.add_argument(
"--local_rank", type=int, default=-1, help="Local rank passed by DeepSpeed"
)
parser.add_argument(
"--no-eval",
action="store_true",
help="If flagged eval will not end at end of training loop.",
)
parser.add_argument("--save-steps", type=int, default=-1)
args = parser.parse_args()
train_model(args)
================================================
FILE: probe_barrier.py
================================================
# probe_barrier.py
import os, sys, time, datetime, argparse
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def log(msg: str, rank: int):
"""timestamped, unbuffered print"""
print(f"[{rank}|{time.time():.3f}] {msg}", flush=True)
def worker(rank: int, world_size: int, backend: str):
# ─── mandatory NCCL housekeeping ────────────────────────────
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
if backend == "nccl":
torch.cuda.set_device(rank) # 1 GPU per rank
# ────────────────────────────────────────────────────────────
dist.init_process_group(
backend = backend,
rank = rank,
world_size = world_size,
timeout = datetime.timedelta(seconds=30) # fail fast
)
log("reached barrier()", rank)
dist.barrier()
log("*** passed barrier()", rank)
# Try another collective just to be sure
tensor = torch.tensor([rank], device="cuda" if backend == "nccl" else "cpu")
dist.all_reduce(tensor)
log(f"all_reduce ok, value={tensor.item()}", rank)
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--nprocs", type=int, default=2)
parser.add_argument("--backend", choices=["gloo", "nccl"], default="gloo")
args = parser.parse_args()
mp.spawn(
worker,
args=(args.nprocs, args.backend),
nprocs=args.nprocs,
join=True
)
if __name__ == "__main__":
# Completely unbuffered stdout/stderr
os.environ["PYTHONUNBUFFERED"] = "1"
main()
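
Aside from the CLI (`python probe_barrier.py --backend nccl --nprocs N`), the probe can also be driven programmatically. The snippet below is a minimal sketch; the process count and the gloo fallback are assumptions, not part of the script.

```python
# Minimal smoke test for probe_barrier.py, run from the repo root.
# Assumes >= 2 GPUs for the NCCL path; otherwise falls back to the gloo backend.
import torch
import torch.multiprocessing as mp

from probe_barrier import worker

if __name__ == "__main__":
    backend = "nccl" if torch.cuda.device_count() >= 2 else "gloo"
    nprocs = torch.cuda.device_count() if backend == "nccl" else 2
    # mp.spawn passes the rank as the first argument, then the args tuple.
    mp.spawn(worker, args=(nprocs, backend), nprocs=nprocs, join=True)
```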
================================================
FILE: route/chat.py
================================================
from typing import List, Dict, Iterator, Tuple
import openai.resources
from abc import ABC, abstractmethod
import openai
from openai import OpenAI
import anthropic
from route.utils import get_registry_decorator
import time
from route.datatypes import (
Roles,
ChatMessage,
ChatCompletionResponse,
Choice,
ChatMessageDelta,
ChoiceDelta,
ChatCompletionResponseChunk,
RouterOutput,
ModelConfig,
)
import logging
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion import ChatCompletion
from anthropic.lib.streaming import MessageStream
from anthropic.types.message_start_event import MessageStartEvent
from uuid import uuid4
class BaseChatHandler(ABC):
@staticmethod
@abstractmethod
def _create_client(model_config: ModelConfig):
pass
@staticmethod
@abstractmethod
def _handle_system_prompt(
messages: List[ChatMessage], model_config: ModelConfig
) -> List[ChatMessage]:
pass
@staticmethod
@abstractmethod
def generate(
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> ChatCompletionResponse:
pass
@staticmethod
@abstractmethod
def generate_stream(
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> Iterator[ChatCompletionResponseChunk]:
pass
CHAT_HANDLERS: Dict[str, BaseChatHandler] = {}
register = get_registry_decorator(CHAT_HANDLERS)
@register("openai")
class OpenAIChatHandler(BaseChatHandler):
@staticmethod
def _create_client(model_config: ModelConfig):
api_key = model_config.get_api_key()
base_url = model_config.get_base_url()
if api_key or base_url:
client = openai.OpenAI(
base_url=base_url,
api_key=api_key,
)
else:
client = openai.OpenAI()
return client
@staticmethod
def _handle_system_prompt(
messages: List[ChatMessage], model_config: ModelConfig
) -> List[ChatMessage]:
system_prompt = model_config.get_system_prompt()
if system_prompt != None and messages[0].role != Roles.SYSTEM.value:
system_message = ChatMessage(
role=Roles.SYSTEM.value,
content=system_prompt,
)
messages = [system_message] + messages
return messages
@staticmethod
def _create_completion(
client: OpenAI,
model_config: ModelConfig,
messages: List[ChatMessage],
temp: float | None,
top_p: float | None,
max_tokens: int | None,
stream=False,
) -> ChatCompletion | Iterator[ChatCompletionChunk]:
completion = client.chat.completions.create(
model=model_config.get_name(),
messages=messages,
temperature=model_config.get_temp() if not temp else temp,
top_p=model_config.get_top_p() if not top_p else top_p,
max_tokens=(
model_config.get_max_tokens(default=openai.NOT_GIVEN)
if not max_tokens
else max_tokens
),
stream=stream,
)
return completion
@classmethod
def generate(
cls,
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> ChatCompletionResponse:
model_config = router_output.chosen_model_config
client = cls._create_client(model_config=model_config)
messages = cls._handle_system_prompt(
messages=messages, model_config=model_config
)
completion: ChatCompletion = cls._create_completion(
client=client,
model_config=model_config,
messages=messages,
temp=temp,
top_p=top_p,
max_tokens=max_tokens,
stream=False,
)
logging.info(f"{int(time.time())} Chosen Model Completion: {completion}")
chat_completion = ChatCompletionResponse(
id=str(completion.id),
object="chat.completion",
created=completion.created,
model=completion.model,
choices=[
Choice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
model=router_output.chosen_model_name,
),
finish_reason=choice.finish_reason,
)
for choice in completion.choices
],
usage=completion.usage,
router_outputs=router_output.model_scores,
)
return chat_completion
    @staticmethod
    def _skip(chunk: ChatCompletionChunk) -> bool:
        try:
            content = chunk.choices[0].delta.content
            return content is None or content == ""
        except Exception:
            return True
@classmethod
def generate_stream(
cls,
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> Iterator[ChatCompletionResponseChunk]:
model_config = router_output.chosen_model_config
client = cls._create_client(model_config=model_config)
messages = cls._handle_system_prompt(
messages=messages, model_config=model_config
)
chunks: Iterator[ChatCompletionChunk] = cls._create_completion(
client=client,
model_config=model_config,
messages=messages,
temp=temp,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
)
first_chunk = True
logging_content = ""
for chunk in chunks:
if cls._skip(chunk):
continue
logging_content += chunk.choices[0].delta.content
out_chunk = ChatCompletionResponseChunk(
id=str(chunk.id),
object="chat.completion.chunk",
created=chunk.created,
model=chunk.model,
choices=[
ChoiceDelta(
index=choice.index,
delta=ChatMessageDelta(
role=choice.delta.role,
content=choice.delta.content,
model=router_output.chosen_model_name,
),
)
for choice in chunk.choices
],
usage=chunk.usage,
router_outputs=router_output.model_scores if first_chunk else None,
).model_dump_json()
yield f"data: {out_chunk}\n\n"
first_chunk = False
logging.info(
f"{int(time.time())} Chat Output (OpenAI Client): {logging_content}"
)
yield "data: [DONE]\n\n"
@register("openai-reasoning")
class OpenaiReasoningChatHandler(OpenAIChatHandler):
@staticmethod
def _create_completion(
client: OpenAI,
model_config: ModelConfig,
messages: List[ChatMessage],
temp: float | None,
top_p: float | None,
max_tokens: int | None,
stream=False,
) -> ChatCompletion | Iterator[ChatCompletionChunk]:
extra_field = model_config.get_extra_fields()
# No max tokens argument
        completion = client.chat.completions.create(
            model=model_config.get_name(),
            messages=messages,
            stream=stream,
            reasoning_effort=extra_field.get("reasoning_effort", openai.NOT_GIVEN),
        )
return completion
@register("openai-o1")
class OpenaiO1ChatHandler(OpenaiReasoningChatHandler):
@classmethod
def generate_stream(
cls,
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> Iterator[ChatCompletionResponseChunk]:
model_config = router_output.chosen_model_config
client = cls._create_client(model_config=model_config)
messages = cls._handle_system_prompt(
messages=messages, model_config=model_config
)
chunk: ChatCompletion = cls._create_completion(
client=client,
model_config=model_config,
messages=messages,
temp=temp,
top_p=top_p,
max_tokens=max_tokens,
stream=False,
)
out_chunk = ChatCompletionResponseChunk(
id=str(chunk.id),
object="chat.completion.chunk",
created=chunk.created,
model=chunk.model,
choices=[
ChoiceDelta(
index=choice.index,
delta=ChatMessageDelta(
role=choice.message.role,
content=choice.message.content,
model=router_output.chosen_model_name,
),
)
for choice in chunk.choices
],
usage=chunk.usage,
router_outputs=router_output.model_scores,
).model_dump_json()
yield f"data: {out_chunk}\n\n"
logging.info(
f"{int(time.time())} Chat Output (OpenAI O1 Client): {chunk.choices[0].message.content}"
)
yield "data: [DONE]\n\n"
@register("anthropic")
class AnthropicChatHandler(BaseChatHandler):
@staticmethod
def _create_client(model_config: ModelConfig):
client = anthropic.Anthropic(api_key=model_config.get_api_key())
return client
@staticmethod
def _handle_system_prompt(
messages: List[ChatMessage], model_config: ModelConfig
) -> Tuple[List[ChatMessage], str | anthropic.NotGiven]:
system_message = model_config.get_system_prompt(default=anthropic.NOT_GIVEN)
if system_message == None:
system_message = anthropic.NOT_GIVEN
if messages[0].role == Roles.SYSTEM.value:
system_message = messages[0].content
messages = messages[1:]
return messages, system_message
@staticmethod
def generate(
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> ChatCompletionResponse:
model_config = router_output.chosen_model_config
client = AnthropicChatHandler._create_client(model_config=model_config)
messages, system_message = AnthropicChatHandler._handle_system_prompt(
messages=messages, model_config=model_config
)
completion = client.messages.create(
model=model_config.get_name(),
messages=messages,
stop_sequences=[anthropic.HUMAN_PROMPT],
temperature=model_config.get_temp() if not temp else temp,
top_p=model_config.get_top_p() if not top_p else top_p,
max_tokens=model_config.get_max_tokens() if not max_tokens else max_tokens,
system=system_message,
)
chat_completion = ChatCompletionResponse(
id=completion.id,
object="chat.completion",
created=int(time.time()),
model=completion.model,
choices=[
Choice(
index=i,
message=ChatMessage(
role=completion.role,
content=content.text,
model=router_output.chosen_model_name,
),
finish_reason=completion.stop_reason,
)
for i, content in enumerate(completion.content)
],
usage=completion.usage,
router_outputs=router_output.model_scores,
)
return chat_completion
@staticmethod
def generate_stream(
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> Iterator[ChatCompletionResponseChunk]:
model_config = router_output.chosen_model_config
client = AnthropicChatHandler._create_client(model_config=model_config)
messages, system_message = AnthropicChatHandler._handle_system_prompt(
messages=messages, model_config=model_config
)
with client.messages.stream(
model=model_config.get_name(),
messages=messages,
stop_sequences=[anthropic.HUMAN_PROMPT],
temperature=model_config.get_temp() if not temp else temp,
top_p=model_config.get_top_p() if not top_p else top_p,
max_tokens=model_config.get_max_tokens() if not max_tokens else max_tokens,
system=system_message,
) as _stream:
stream: MessageStream = _stream
# This contains the metadata
message_start: MessageStartEvent = next(stream)
resp_id = message_start.message.id
model = message_start.message.model
role = message_start.message.role
# Ignore this useless chunk.
next(stream)
first_chunk = True
logging_content = ""
for text in stream.text_stream:
logging_content += text
out_chunk = ChatCompletionResponseChunk(
id=resp_id,
created=int(time.time()),
model=model,
object="chat.completion.chunk",
choices=[
ChoiceDelta(
delta=ChatMessageDelta(
content=text,
role=role,
model=router_output.chosen_model_name,
),
index=0,
)
],
router_outputs=router_output.model_scores if first_chunk else None,
).model_dump_json()
yield f"data: {out_chunk}\n\n"
first_chunk = False
logging.info(
f"{int(time.time())} Chat Output (Anthropic Client): {logging_content}"
)
yield "data: [DONE]\n\n"
import google.generativeai as genai
from google.generativeai.types.generation_types import GenerateContentResponse
@register("gemini")
class GeminiChatHandler(BaseChatHandler):
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]
@staticmethod
def _create_client(model_config: ModelConfig):
api_key = model_config.get_api_key()
if api_key:
genai.configure(api_key=api_key)
@staticmethod
def _handle_system_prompt(
messages: List[ChatMessage], model_config: ModelConfig
) -> List[ChatMessage]:
system_prompt = model_config.get_system_prompt()
if system_prompt != None and messages[0].role != Roles.SYSTEM.value:
system_message = ChatMessage(
role=Roles.SYSTEM.value,
content=system_prompt,
)
messages = [system_message] + messages
return messages
@staticmethod
def _create_completion(
model_config: ModelConfig,
messages: List[ChatMessage],
temp: float | None,
top_p: float | None,
max_tokens: int | None,
stream=False,
) -> GenerateContentResponse | Iterator[GenerateContentResponse]:
generation_config = genai.GenerationConfig(
max_output_tokens=model_config.get_max_tokens(default=8192) if not max_tokens else max_tokens,
temperature=model_config.get_temp() if not temp else temp,
top_p=model_config.get_top_p() if not top_p else top_p,
top_k=model_config.get_top_k(),
)
history = []
system_prompt = None
for message in messages[:-1]:
if message.role == Roles.SYSTEM.value:
system_prompt = message.content
elif message.role == Roles.ASSISTANT.value:
history.append({"role": "model", "parts": message.content})
else:
history.append({"role": "user", "parts": message.content})
model = genai.GenerativeModel(
model_name=model_config.get_name(),
system_instruction=system_prompt,
generation_config=generation_config,
safety_settings=GeminiChatHandler.safety_settings,
)
chat_session = model.start_chat(history=history)
completion = chat_session.send_message(
content=messages[-1].content, stream=stream
)
return completion
@classmethod
def generate(
cls,
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> ChatCompletionResponse:
model_config = router_output.chosen_model_config
cls._create_client(model_config=model_config)
messages = cls._handle_system_prompt(
messages=messages, model_config=model_config
)
completion: GenerateContentResponse = cls._create_completion(
model_config=model_config,
messages=messages,
temp=temp,
top_p=top_p,
max_tokens=max_tokens,
stream=False,
)
logging.info(f"{int(time.time())} Chosen Model Completion: {completion}")
chat_completion = ChatCompletionResponse(
id=str(uuid4()),
object="chat.completion",
created=int(time.time()),
model=model_config.get_name(),
choices=[
Choice(
index=0,
message=ChatMessage(
role=Roles.ASSISTANT.value,
content=completion.text,
model=router_output.chosen_model_name,
),
finish_reason="STOP",
)
],
router_outputs=router_output.model_scores,
)
return chat_completion
@classmethod
def generate_stream(
cls,
messages: List[ChatMessage],
router_output: RouterOutput,
temp: float | None,
top_p: float | None,
max_tokens: int | None,
) -> Iterator[ChatCompletionResponseChunk]:
model_config = router_output.chosen_model_config
cls._create_client(model_config=model_config)
messages = cls._handle_system_prompt(
messages=messages, model_config=model_config
)
chunks: Iterator[GenerateContentResponse] = cls._create_completion(
model_config=model_config,
messages=messages,
temp=temp,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
)
first_chunk = True
chat_id = str(uuid4())
logging_content = ""
for chunk in chunks:
logging_content += chunk.text
out_chunk = ChatCompletionResponseChunk(
id=chat_id,
object="chat.completion.chunk",
created=int(time.time()),
model=model_config.get_name(),
choices=[
ChoiceDelta(
index=0,
delta=ChatMessageDelta(
role=Roles.ASSISTANT.value,
content=chunk.text,
model=router_output.chosen_model_name,
),
)
],
router_outputs=router_output.model_scores if first_chunk else None,
).model_dump_json()
yield f"data: {out_chunk}\n\n"
first_chunk = False
logging.info(
f"{int(time.time())} Chat Output (Gemini Client): {logging_content}"
)
yield "data: [DONE]\n\n"
================================================
FILE: route/cost_optimizers.py
================================================
from abc import ABC, abstractmethod
from route.utils import get_registry_decorator
from typing import List, Dict
import numpy as np
import cvxpy as cp
from scipy.special import expit
class UnfulfillableException(Exception):
pass
class BaseCostOptimizer(ABC):
def __init__(self):
super().__init__()
@staticmethod
@abstractmethod
def select_model(
cost: float,
model_list: List[str],
model_costs: np.ndarray[float],
model_scores: np.ndarray[float],
**kwargs,
) -> str:
pass
@staticmethod
def select_max_score_model(
model_list: List[str], model_scores: np.ndarray[float]
) -> str:
max_idx = np.argmax(model_scores)
return model_list[max_idx]
COST_OPTIMIZERS: Dict[str, BaseCostOptimizer] = {}
register = get_registry_decorator(COST_OPTIMIZERS)
@register("strict")
class StrictCostOptimizer(BaseCostOptimizer):
def __init__(self):
super().__init__()
@staticmethod
def select_model(
cost: float | None,
model_list: List[str],
model_costs: np.ndarray[float],
model_scores: np.ndarray[float],
**kwargs,
) -> str:
if cost == None:
return StrictCostOptimizer.select_max_score_model(model_list, model_scores)
best_model: str | None = None
best_score = -float("inf")
for model, model_cost, model_score in zip(
model_list, model_costs, model_scores
):
if model_cost > cost:
continue
elif model_score > best_score:
best_model = model
best_score = model_score
if best_model is None:
raise UnfulfillableException(
f"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}."
)
return best_model
@register("simple-lp")
class SimpleLPCostOptimizer(BaseCostOptimizer):
def __init__(self):
super().__init__()
@staticmethod
def select_model(
cost: float | None,
model_list: List[str],
model_costs: np.ndarray[float],
model_scores: np.ndarray[float],
**kwargs,
) -> str:
if cost == None:
return StrictCostOptimizer.select_max_score_model(model_list, model_scores)
p = cp.Variable(len(model_costs))
prob = cp.Problem(
cp.Maximize(cp.sum(model_scores @ p)),
[model_costs.T @ p <= cost, cp.sum(p) == 1, p >= 0],
)
        prob.solve()
        if prob.status != cp.OPTIMAL:
            raise UnfulfillableException(
                f"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}."
            )
ps = np.clip(p.value, a_min=0.0, a_max=1.0)
ps = ps / ps.sum()
return np.random.choice(model_list, p=ps)
@register("optimal-lp")
class OptimalLPCostOptimizer(BaseCostOptimizer):
def __init__(self):
super().__init__()
@staticmethod
def select_model(
cost: float | None,
model_list: List[str],
model_costs: np.ndarray[float],
model_scores: np.ndarray[float],
opponent_scores: np.ndarray[float] = None,
opponent_distribution: np.ndarray[float] = None,
) -> str:
if cost == None:
return StrictCostOptimizer.select_max_score_model(model_list, model_scores)
W = OptimalLPCostOptimizer._construct_W(model_scores, opponent_scores)
Wq = W @ opponent_distribution
p = cp.Variable(len(model_costs))
prob = cp.Problem(
cp.Maximize(p @ Wq), [model_costs.T @ p <= cost, cp.sum(p) == 1, p >= 0]
)
        prob.solve()
        if prob.status != cp.OPTIMAL:
            raise UnfulfillableException(
                f"Cost of {cost} impossible to fulfill with available models {model_list} with costs {model_costs}."
            )
ps = np.clip(p.value, a_min=0.0, a_max=1.0)
ps = ps / ps.sum()
return np.random.choice(model_list, p=ps)
@staticmethod
def _construct_W(
router_model_scores: np.ndarray[float], opponent_model_scores: np.ndarray[float]
) -> np.ndarray[float]:
num_rows = router_model_scores.shape[-1]
num_cols = opponent_model_scores.shape[-1]
chosen = np.tile(router_model_scores, (num_cols, 1)).T
rejected = np.tile(opponent_model_scores, (num_rows, 1))
assert chosen.shape == rejected.shape, (chosen.shape, rejected.shape)
diff_matrix = chosen - rejected
W = expit(diff_matrix)
return W
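
A toy numeric example may help make the optimizer interface concrete. The model names, costs, and scores below are invented; the `strict` optimizer simply picks the best-scoring model within budget, while the LP optimizers instead sample from a distribution that satisfies the cost constraint in expectation.

```python
# Toy illustration of the cost optimizers above (all values invented).
import numpy as np

from route.cost_optimizers import COST_OPTIMIZERS

models = ["cheap-model", "mid-model", "frontier-model"]
costs = np.array([0.1, 1.0, 10.0])
scores = np.array([0.2, 0.5, 0.9])

strict = COST_OPTIMIZERS["strict"]
choice = strict.select_model(
    cost=2.0, model_list=models, model_costs=costs, model_scores=scores
)
print(choice)  # "mid-model": highest score among models costing <= 2.0
```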
================================================
FILE: route/datatypes.py
================================================
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from pydantic import BaseModel
from enum import Enum
class ModelConfig:
def __init__(self, config: Dict[str, Any]):
self.config = config
def get_name(self) -> str:
return self.config["name"]
def get_temp(self) -> float:
return self.config["temp"]
def get_top_p(self) -> float:
return self.config["top_p"]
def get_top_k(self, default=None) -> int:
return self.config.get("top_k", default)
def get_system_prompt(self, default=None) -> str | None | Any:
return self.config.get("system_prompt", default)
def get_api_key(self, default=None) -> str | None | Any:
return self.config.get("api_key", default)
def get_base_url(self, default=None) -> str | None | Any:
return self.config.get("base_url", default)
def get_type(self) -> str:
return self.config["type"]
def get_cost(self) -> float:
return self.config["cost"]
def get_max_tokens(self, default=None) -> int | None | Any:
return self.config.get("max_tokens", default)
def get_extra_fields(self) -> Dict:
return self.config.get("extra_fields", {}) # Maybe should be None...
def __repr__(self):
return repr(
dict(
name=self.get_name(),
type=self.get_type(),
cost=self.get_cost(),
)
)
class ModelConfigContainer:
def __init__(self, model_config_dicts: Dict[str, Dict[str, Any]]):
self.model_configs: Dict[str, ModelConfig] = dict(
(name, ModelConfig(config)) for name, config in model_config_dicts.items()
)
def get_model_config(self, model_name: str) -> ModelConfig:
return self.model_configs[model_name]
def list_models(self) -> List[str]:
return list(self.model_configs.keys())
def list_costs(self) -> List[float]:
costs: List[float] = []
for model_name in self.list_models():
model_config = self.get_model_config(model_name)
costs.append(model_config.get_cost())
return costs
def __repr__(self):
return repr(self.model_configs)
class Roles(Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class ChatMessage(BaseModel):
"""
Represents a single message in the conversation.
role: "system", "user", or "assistant"
content: the actual text
"""
role: str
content: str
model: Optional[str] = None
class ChatCompletionRequest(BaseModel):
"""
Request body for Chat Completion.
"""
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = 1
stream: Optional[bool] = False
stop: Optional[List[str]] = None
cost: Optional[float] = None
direct_model: Optional[str] = None
class Choice(BaseModel):
"""
Represents a single choice in the final response (non-streaming mode).
"""
index: int
message: ChatMessage
finish_reason: str
class ChatCompletionResponse(BaseModel):
"""
Response model for non-streaming mode.
"""
id: str
object: str
created: int
model: str
choices: List[Choice]
usage: Optional[BaseModel] = None
router_outputs: Optional[Dict[str, float]] = None
class ChatMessageDelta(BaseModel):
content: Optional[str] = None
role: Optional[str] = None
model: Optional[str] = None
class ChoiceDelta(BaseModel):
delta: ChatMessageDelta
finish_reason: Optional[str] = None
index: int
class ChatCompletionResponseChunk(BaseModel):
id: str
choices: List[ChoiceDelta]
created: int
model: str
object: str
usage: Optional[BaseModel] = None
router_outputs: Optional[Dict[str, float]] = None
@dataclass
class RouterOutput:
chosen_model_name: str
chosen_model_config: ModelConfig
model_scores: Dict[str, float] | None
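
A quick sketch of constructing the request type above with pydantic; all values are invented.

```python
# Sketch: building a ChatCompletionRequest the way the server expects it.
from route.datatypes import ChatCompletionRequest, ChatMessage

request = ChatCompletionRequest(
    model="p2l-router",
    messages=[ChatMessage(role="user", content="What is a Bradley-Terry coefficient?")],
    stream=False,
    cost=1.5,  # optional per-request cost ceiling consumed by the cost optimizer
)
print(request.model_dump_json(indent=2))
```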
================================================
FILE: route/example_config.yaml
================================================
model_configs:
athene-v2-chat:
api_key: <your-api-key>
base_url: http://38.142.9.21:10245/v1
cost: 0.8097264049
name: im-a-little-birdie
temp: 0.7
top_p: 1.0
type: openai
claude-3-5-haiku-20241022:
api_key: <your-api-key>
base_url: null
cost: 2.1765185825
max_tokens: 8192
name: claude-3-5-haiku-20241022
temp: 0.7
top_p: 0.7
type: anthropic
claude-3-5-sonnet-20240620:
api_key: <your-api-key>
base_url: null
cost: 9.4453041863
max_tokens: 8192
name: claude-3-5-sonnet-20240620
system_prompt: '<claude_info> The assistant is Claude, created by Anthropic. The
current date is 2025-01-06. Claude''s knowledge base was last updated on April
2024. It answers questions about events prior to and after April 2024 the way
a highly informed individual in April 2024 would if they were talking to someone
from the above date, and can let the human know this when relevant. Claude cannot
open URLs, links, or videos. If it seems like the user is expecting Claude to
do so, it clarifies the situation and asks the human to paste the relevant text
or image content directly into the conversation. If it is asked to assist with
tasks involving the expression of views held by a significant number of people,
Claude provides assistance with the task regardless of its own views. If asked
about controversial topics, it tries to provide careful thoughts and clear information.
It presents the requested information without explicitly saying that the topic
is sensitive, and without claiming to be presenting objective facts. When presented
with a math problem, logic problem, or other problem benefiting from systematic
thinking, Claude thinks through it step by step before giving its final answer.
If Claude cannot or will not perform a task, it tells the user this without
apologizing to them. It avoids starting its responses with "I''m sorry" or "I
apologize". If Claude is asked about a very obscure person, object, or topic,
i.e. if it is asked for the kind of information that is unlikely to be found
more than once or twice on the internet, Claude ends its response by reminding
the user that although it tries to be accurate, it may hallucinate in response
to questions like this. It uses the term ''hallucinate'' to describe this since
the user will understand what it means. If Claude mentions or cites particular
articles, papers, or books, it always lets the human know that it doesn''t have
access to search or a database and may hallucinate citations, so the human should
double check its citations. Claude is very smart and intellectually curious.
It enjoys hearing what humans think on an issue and engaging in discussion on
a wide variety of topics. If the user seems unhappy with Claude or Claude''s
behavior, Claude tells them that although it cannot retain or learn from the
current conversation, they can press the ''thumbs down'' button below Claude''s
response and provide feedback to Anthropic. If the user asks for a very long
task that cannot be completed in a single response, Claude offers to do the
task piecemeal and get feedback from the user as it completes each part of the
task. Claude uses markdown for code. Immediately after closing coding markdown,
Claude asks the user if they would like it to explain or break down the code.
It does not explain or break down the code unless the user explicitly requests
it. </claude_info>
<claude_3_family_info> This iteration of Claude is part of the Claude 3 model
family, which was released in 2024. The Claude 3 family currently consists of
Claude 3 Haiku, Claude 3 Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the
most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude
3 Haiku is the fastest model for daily tasks. The version of Claude in this
chat is Claude 3.5 Sonnet. Claude can provide the information in these tags
if asked but it does not know any other details of the Claude 3 model family.
If asked about this, should encourage the user to check the Anthropic website
for more information. </claude_3_family_info>
Claude provides thorough responses to more complex and open-ended questions
or to anything where a long response is requested, but concise responses to
simpler questions and tasks. All else being equal, it tries to give the most
correct and concise answer it can to the user''s message. Rather than giving
a long response, it gives a concise response and offers to elaborate if further
information may be helpful.
Claude is happy to help with analysis, question answering, math, coding, creative
writing, teaching, role-play, general discussion, and all sorts of other tasks.
Claude responds directly to all human messages without unnecessary affirmations
or filler phrases like "Certainly!", "Of course!", "Absolutely!", "Great!",
"Sure!", etc. Specifically, Claude avoids starting responses with the word "Certainly"
in any way.
Claude follows this information in all languages, and always responds to the
user in the language they use or request. The information above is provided
to Claude by Anthropic. Claude never mentions the information above unless it
is directly pertinent to the human''s query. Claude is now being connected with
a human.
'
temp: 0.7
top_p: 0.7
type: anthropic
claude-3-5-sonnet-20241022:
api_key: <your-api-key>
base_url: null
cost: 9.3110239362
max_tokens: 8192
name: claude-3-5-sonnet-20241022
system_prompt: null
temp: 0.7
top_p: 0.7
type: anthropic
deepseek-v3:
api_key: <your-api-key>
base_url: https://api.deepseek.com
cost: 0.3002758331
name: deepseek-chat
temp: 1.5
top_p: 1.0
type: openai
gemini-1.5-flash-001:
api_key: <your-api-key>
cost: 0.4549682765
name: gemini-1.5-flash-001
temp: 0.7
top_p: 1.0
type: gemini
gemini-1.5-flash-002:
api_key: <your-api-key>
cost: 0.6330942997
name: gemini-1.5-flash-002
system_prompt: All questions should be answered comprehensively with details,
unless the user requests a concise response specifically. Respond in the same
language as the query.
temp: 0.7
top_p: 1.0
type: gemini
gemini-1.5-pro-001:
api_key: <your-api-key>
cost: 6.7456245955
name: gemini-1.5-pro-001
temp: 0.7
top_p: 0.7
type: gemini
gemini-1.5-pro-002:
api_key: <your-api-key>
cost: 9.6885059428
name: gemini-1.5-pro-002-test
system_prompt: All questions should be answered comprehensively with details,
unless the user requests a concise response specifically. Respond in the same
language as the query.
temp: 0.7
top_p: 1.0
type: gemini
gemini-2.0-flash-exp:
api_key: <your-api-key>
cost: 0.8978088229
name: gemini-test-14
temp: 1.0
top_k: 64
top_p: 0.95
type: gemini
gemini-2.0-flash-thinking-exp-1219:
api_key: <your-api-key>
cost: 0.4626591495
name: gemini-test-15
temp: 1.0
top_k: 64
top_p: 0.95
type: gemini
gemini-exp-1206:
api_key: <your-api-key>
cost: 6.7210154899
name: gemini-test-12
temp: 1.0
top_k: 64
top_p: 0.95
type: gemini
gemma-2-27b-it:
api_key: <your-api-key>
cost: 0.4732936067
name: gemma-2-27b-no-filter
temp: 0.7
top_p: 0.7
type: gemini
gemma-2-9b-it:
api_key: <your-api-key>
cost: 0.0873672873
name: gemma-2-9b-no-filter
temp: 0.7
top_p: 1.0
type: gemini
glm-4-plus:
api_key: <your-api-key>
base_url: https://open.bigmodel.cn/api/paas/v4
cost: 0.3175377664
name: glm-4-plus
temp: 0.7
top_p: 1.0
type: openai
gpt-4-1106-preview:
api_key: <your-api-key>
base_url: null
cost: 16.3622976323
name: gpt-4-1106-preview
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
gpt-4-turbo-2024-04-09:
api_key: <your-api-key>
base_url: null
cost: 17.4092447612
name: gpt-4-turbo-2024-04-09
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
gpt-4o-2024-05-13:
api_key: <your-api-key>
base_url: null
cost: 12.3166873868
name: gpt-4o-2024-05-13
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
gpt-4o-2024-08-06:
api_key: <your-api-key>
base_url: null
cost: 6.9944337124
name: gpt-4o-2024-08-06
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
gpt-4o-mini-2024-07-18:
api_key: <your-api-key>
base_url: null
cost: 0.563652953
name: gpt-4o-mini-2024-07-18
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based
on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
temp: 0.7
top_p: 1.0
type: openai
llama-3-70b-instruct:
api_key: <your-api-key>
base_url: https://api.together.xyz/v1
cost: 0.4186380435
name: meta-llama/Llama-3-70b-chat-hf
temp: 0.7
top_p: 1.0
type: openai
llama-3.1-405b-instruct-fp8:
api_key: <your-api-key>
base_url: https://api.fireworks.ai/inference/v1
cost: 2.4340008579
name: accounts/fireworks/models/llama-v3p1-405b-instruct
system_prompt: 'Cutting Knowledge Date: December 2023
Today Date: 06 Jan 2025'
temp: 0.6
top_p: 1.0
type: openai
llama-3.1-70b-instruct:
api_key: <your-api-key>
base_url: https://api.fireworks.ai/inference/v1
cost: 0.7204016024
name: accounts/fireworks/models/llama-v3p1-70b-instruct
system_prompt: "Cutting Knowledge Date: December 2023\nToday Date: 06 Jan 2025\n\
\nCarefully read the user prompt. Your responses are comprehensive and easy\
\ to understand. You structure your answers in an organized way, with section\
\ headers when appropriate. You use consistent formatting in your responses.\
\ You follow user instructions. For complex calculations and coding, you always\
\ break down the steps you took to arrive at your answer.\n\nPay extra attention\
\ to prompts in the following categories:\n * Non-English queries: Read the\
\ prompt carefully and pay close attention to formatting requests and the level\
\ of detail; ensure you are giving factual and precise responses using correct\
\ grammar in the correct language.\n * Coding queries: You prioritize code organization\
\ and documentation. Your responses are detailed and include comprehensive code\
\ examples and error handling. Include comments to explain the code's purpose\
\ and behavior. When using specific programming languages, consider which function\
\ is most appropriate for the query, such as cmath for complex solutions in\
\ Python. Check for errors.\n * For mathematical reasoning: Before responding,\
\ review your output for reasoning, algebraic manipulation and calculation errors\
\ and fix before responding. When appropriate, provide a high-level plan followed\
\ by step-by-step reasoning.\n\nRemember your instructions."
temp: 0.7
top_p: 1.0
type: openai
llama-3.1-8b-instruct:
api_key: <your-api-key>
base_url: https://api.fireworks.ai/inference/v1
cost: 0.1573721045
name: accounts/fireworks/models/llama-v3p1-8b-instruct
system_prompt: "Cutting Knowledge Date: December 2023\nToday Date: 06 Jan 2025\n\
\nCarefully read the user prompt. Your responses are comprehensive and easy\
\ to understand. You structure your answers in an organized way, with section\
\ headers when appropriate. You use consistent formatting in your responses.\
\ You follow user instructions. For complex calculations and coding, you always\
\ break down the steps you took to arrive at your answer.\n\nPay extra attention\
\ to prompts in the following categories:\n * Non-English queries: Read the\
\ prompt carefully and pay close attention to formatting requests and the level\
\ of detail; ensure you are giving factual and precise responses using correct\
\ grammar in the correct language.\n * Coding queries: You prioritize code organization\
\ and documentation. Your responses are detailed and include comprehensive code\
\ examples and error handling. Include comments to explain the code's purpose\
\ and behavior. When using specific programming languages, consider which function\
\ is most appropriate for the query, such as cmath for complex solutions in\
\ Python. Check for errors.\n * For mathematical reasoning: Before responding,\
\ review your output for reasoning, algebraic manipulation and calculation errors\
\ and fix before responding. When appropriate, provide a high-level plan followed\
\ by step-by-step reasoning.\n\nRemember your instructions."
temp: 0.7
top_p: 1.0
type: openai
llama-3.3-70b-instruct:
api_key: <your-api-key>
base_url: https://api.fireworks.ai/inference/v1
cost: 0.706256804
name: accounts/fireworks/models/llama-v3p3-70b-instruct
temp: 0.6
top_p: 1.0
type: openai
mistral-large-2407:
api_key: <your-api-key>
base_url: https://api.mistral.ai/v1
cost: 4.3956843814
name: mistral-large-2407
temp: 0.7
top_p: 0.7
type: openai
mixtral-8x22b-instruct-v0.1:
api_key: <your-api-key>
base_url: https://api.mistral.ai/v1
cost: 2.5814904104
name: mixtral-8x22b-instruct-v0.1
temp: 0.7
top_p: 0.7
type: openai
mixtral-8x7b-instruct-v0.1:
api_key: <your-api-key>
base_url: https://api.together.xyz/v1
cost: 0.2839726899
name: mistralai/Mixtral-8x7B-Instruct-v0.1
temp: 0.7
top_p: 0.7
type: openai
o1-2024-12-17:
api_key: <your-api-key>
cost: 72.3693462194
name: o1-2024-12-17
system_prompt: Formatting re-enabled.
temp: 1.0
top_p: 1.0
type: openai-o1
o1-mini:
api_key: <your-api-key>
base_url: null
cost: 16.4809912657
name: o1-mini-2024-09-12
system_prompt: null
temp: 1.0
top_p: 1.0
type: openai-reasoning
o1-preview:
api_key: <your-api-key>
base_url: null
cost: 72.481802295
name: o1-preview
system_prompt: null
temp: 1.0
top_p: 1.0
type: openai-reasoning
qwen2.5-72b-instruct:
api_key: <your-api-key>
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
cost: 1.1805173434
name: qwen2.5-72b-instruct
temp: 0.7
top_p: 1.0
type: openai
yi-lightning:
api_key: <your-api-key>
base_url: https://api.lingyiwanwu.com/v1
cost: 0.0057351688
name: yi-lightning
temp: 0.6
top_p: 1.0
type: openai
chatgpt-4o-latest-20241120:
api_key: <your-api-key>
cost: 12.9070929223
name: gpt-4o-2024-11-20
temp: 0.7
top_p: 1.0
system_prompt: 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.
Current date: 2025-01-06
Image input capabilities: Enabled
Personality: v2'
type: openai
name: test-router
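
A minimal sketch of how a config in this shape gets consumed, mirroring the startup logic in `route/openai_server.py`; the file path below is an assumption.

```python
# Load the example config and build the container the routers consume.
import yaml

from route.datatypes import ModelConfigContainer

with open("route/example_config.yaml") as f:
    config = yaml.safe_load(f)

container = ModelConfigContainer(config["model_configs"])
print(container.list_models())  # model names available to the router
print(container.list_costs())   # per-model costs used by the cost optimizer
```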
================================================
FILE: route/openai_server.py
================================================
import argparse
from fastapi import FastAPI, HTTPException, Header
from fastapi.responses import StreamingResponse
from route.datatypes import (
ModelConfigContainer,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChunk,
)
from route.chat import CHAT_HANDLERS
from route.routers import ROUTERS, BaseRouter
import uvicorn
import yaml
from contextlib import asynccontextmanager
from typing import List
import logging
import time
import sys
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c", type=str, required=True)
parser.add_argument("--router-type", type=str, required=True)
parser.add_argument("--router-model-name", type=str, default=None)
parser.add_argument("--router-model-endpoint", type=str, default=None)
parser.add_argument("--router-api-key", type=str, default="-")
parser.add_argument("--cost-optimizer", type=str, default="simple-lp")
parser.add_argument("--port", "-p", type=int, default=8000)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--api-key", type=str, default="-")
parser.add_argument("--reload", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--workers", type=int, default=1)
args = parser.parse_args()
return args
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
This context manager is called once at startup and once at shutdown.
We move all config-loading and router-creation logic here.
"""
# --- PARSE ARGS & LOAD CONFIG ---
logging.info(f"Starting up...")
args = parse_args()
with open(args.config) as cfile:
config = yaml.safe_load(cfile)
model_config_dicts = config["model_configs"]
model_config_container = ModelConfigContainer(model_config_dicts)
router_cls = ROUTERS[args.router_type]
router_kwargs = {
"router_model_name": args.router_model_name,
"router_model_endpoint": args.router_model_endpoint,
"router_api_key": args.router_api_key,
}
router = router_cls(model_config_container, args.cost_optimizer, **router_kwargs)
app.state.router = router
app.state.model_config_container = model_config_container
app.state.api_key = args.api_key
logging.info(f"Finished startup.")
try:
yield
finally:
pass
app = FastAPI(lifespan=lifespan)
# ====== API Endpoint ======
@app.post("/v1/chat/completions")
async def create_chat_completion(
request: ChatCompletionRequest,
authorization: str = Header(None),
) -> ChatCompletionResponse | ChatCompletionResponseChunk:
"""
Mimics the OpenAI Chat Completions endpoint (both streaming and non-streaming).
"""
logging.info(f"{int(time.time())} Recieved Request: {request}")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid or missing API key")
# Strip out the 'Bearer ' portion to isolate the token
token = authorization.removeprefix("Bearer ")
if token != app.state.api_key:
raise HTTPException(status_code=403, detail="Unauthorized")
try:
        router_output = None
        handler_type = None
        direct_model = request.direct_model
        router: BaseRouter = app.state.router
        messages = request.messages
        if direct_model:
            router_output = router.get_model_direct(direct_model)
        else:
            router_output = router.route(messages, request.cost)
        logging.info(f"{int(time.time())} Router Output: {router_output}")
        handler_type = router_output.chosen_model_config.get_type()
        chat_handler = CHAT_HANDLERS[handler_type]
    except Exception as e:
        logging.info(
            f"{int(time.time())} ***Routing Error Start***\nError Message: {e}\nRouter Output: {router_output}\nChat Handler: {handler_type}\nDirect Model: {direct_model}.***Routing Error End***"
        )
        raise HTTPException(status_code=500, detail=str(e))
try:
if request.stream:
chat_output_chunk = chat_handler.generate_stream(
messages=messages,
router_output=router_output,
temp=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
)
return StreamingResponse(chat_output_chunk, media_type="text/event-stream")
else:
chat_output = chat_handler.generate(
messages=messages,
router_output=router_output,
temp=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
)
return chat_output
    except Exception as e:
        logging.info(
            f"{int(time.time())} ***Endpoint Error Start***\nError Message: {e}\nRouter Output: {router_output}\nChat Handler: {handler_type}.***Endpoint Error End***"
        )
        raise e
@app.get("/v1/models")
async def models(authorization: str = Header(None)) -> List[str]:
logging.info(f"Recieved Get Request for Models.")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid or missing API key")
# Strip out the 'Bearer ' portion to isolate the token
token = authorization.removeprefix("Bearer ")
if token != app.state.api_key:
raise HTTPException(status_code=403, detail="Unauthorized")
router: BaseRouter = app.state.router
return router.model_list
if __name__ == "__main__":
args = parse_args()
uvicorn.run(
"route.openai_server:app",
port=args.port,
host=args.host,
reload=args.reload,
workers=args.workers,
)
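
Because the endpoint mirrors the OpenAI Chat Completions API, any OpenAI-compatible client can talk to it. The sketch below assumes the server is running locally with the default `--port 8000` and `--api-key -`; the `cost` field is the optional routing extension defined in `ChatCompletionRequest`, and the `model` value is a placeholder the router ignores unless `direct_model` is supplied.

```python
# Client-side sketch against a locally running router (argparse defaults assumed).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")

response = client.chat.completions.create(
    model="p2l-router",  # placeholder; the router picks the underlying model itself
    messages=[{"role": "user", "content": "Summarize the Bradley-Terry model."}],
    extra_body={"cost": 2.0},  # optional cost ceiling from ChatCompletionRequest
)
print(response.choices[0].message.content)
```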
================================================
FILE: route/requirements.txt
================================================
uvicorn
fastapi
openai
anthropic
google-generativeai
scipy
cvxpy
================================================
FILE: route/routers.py
================================================
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from route.utils import (
get_registry_decorator,
query_p2l_endpoint,
get_p2l_endpoint_models,
)
from route.datatypes import ModelConfigContainer, Roles, ChatMessage, RouterOutput
from route.cost_optimizers import COST_OPTIMIZERS, BaseCostOptimizer
import numpy as np
from scipy.special import expit
class BaseRouter(ABC):
def __init__(
self,
model_config_container: ModelConfigContainer,
cost_optimizer_type: str,
**kwargs,
):
super().__init__()
self.model_config_container = model_config_container
self.model_list: List[str] = None
self.model_costs: np.ndarray[float] = None
self.cost_optimizer: BaseCostOptimizer = COST_OPTIMIZERS[cost_optimizer_type]
@abstractmethod
def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:
pass
def _get_previous_response_model(self, messages: List[ChatMessage]) -> str | None:
for message in reversed(messages):
if message.role == Roles.ASSISTANT.value:
return message.model
return None
def _get_prompt(self, messages: List[ChatMessage]) -> list[str]:
prompts = []
for message in messages:
if message.role == Roles.USER.value:
prompts.append(message.content)
if len(prompts) == 0:
raise Exception(f"No user prompt found in messages {messages}.")
return prompts
def get_model_direct(self, model_name: str) -> RouterOutput:
return RouterOutput(
chosen_model_name=model_name,
chosen_model_config=self.model_config_container.get_model_config(
model_name=model_name
),
model_scores=None,
)
def route(self, messages: List[ChatMessage], cost: float = None) -> RouterOutput:
model_scores = self._get_model_scores(messages)
chosen_model_name = self.cost_optimizer.select_model(
cost, self.model_list, self.model_costs, model_scores
)
model_scores_dict = dict(zip(self.model_list, model_scores))
chosen_model_config = self.model_config_container.get_model_config(
chosen_model_name
)
return RouterOutput(
chosen_model_name=chosen_model_name,
chosen_model_config=chosen_model_config,
model_scores=model_scores_dict,
)
ROUTERS: Dict[str, BaseRouter] = {}
register = get_registry_decorator(ROUTERS)
@register("random")
class RandomRouter(BaseRouter):
"""For debugging and gamblers."""
def __init__(
self,
model_config_container: ModelConfigContainer,
cost_optimizer_type: str,
**kwargs,
):
super().__init__(
model_config_container=model_config_container,
cost_optimizer_type=cost_optimizer_type,
)
self.model_list = model_config_container.list_models()
self.model_costs = np.array(model_config_container.list_costs())
def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:
return np.random.uniform(0.0, 1.0, size=len(self.model_list))
@register("bt-endpoint")
class EndpointP2LRouter(BaseRouter):
# Hardcoding this because I'm tired man...
SAMPLING_WEIGHTS = {
"chatgpt-4o-latest-20241120": 4,
"o1-mini": 4,
"o1-2024-12-17": 4,
"gpt-4o-mini-2024-07-18": 2,
"gemma-2-27b-it": 2,
"gemma-2-9b-it": 2,
"gemma-2-2b-it": 2,
"claude-3-5-sonnet-20241022": 4,
"claude-3-opus-20240229": 4,
"claude-3-5-haiku-20241022": 4,
"qwen2.5-72b-instruct": 2,
"qwen2.5-plus-1127": 4,
"llama-3.1-405b-instruct-bf16": 4,
"mistral-large-2411": 4,
"grok-2-2024-08-13": 4,
"grok-2-mini-2024-08-13": 2,
"deepseek-v3": 6,
"gemini-1.5-pro-002": 4,
"gemini-1.5-flash-002": 2,
"gemini-1.5-flash-8b-001": 2,
"c4ai-aya-expanse-32b": 2,
"c4ai-aya-expanse-8b": 2,
"athene-v2-chat": 4,
"gemini-exp-1206": 4,
"gemini-2.0-flash-exp": 4,
"llama-3.3-70b-instruct": 4,
"amazon-nova-pro-v1.0": 4,
"amazon-nova-lite-v1.0": 2,
"amazon-nova-micro-v1.0": 2,
"llama-3.1-tulu-3-8b": 6,
"llama-3.1-tulu-3-70b": 6,
"granite-3.1-8b-instruct": 6,
"granite-3.1-2b-instruct": 6,
}
def __init__(
self,
model_config_container: ModelConfigContainer,
cost_optimizer_type: str,
router_model_endpoint: str,
router_api_key: str,
**kwargs,
):
super().__init__(
model_config_container=model_config_container,
cost_optimizer_type=cost_optimizer_type,
)
self.base_url = router_model_endpoint
self.api_key = router_api_key
router_model_list = get_p2l_endpoint_models(self.base_url, self.api_key)
config_model_list = model_config_container.list_models()
self.mask = [
router_model in config_model_list for router_model in router_model_list
]
self.q_mask = [
router_model in self.SAMPLING_WEIGHTS for router_model in router_model_list
]
self.q = np.array(
[
float(self.SAMPLING_WEIGHTS[router_model])
for router_model in router_model_list
if router_model in self.SAMPLING_WEIGHTS
]
)
self.model_list = [
model for model, keep in zip(router_model_list, self.mask) if keep
]
self.model_costs = np.array(
[
model_config_container.get_model_config(model).get_cost()
for model in self.model_list
]
)
    def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:
prompt = self._get_prompt(messages)
p2l_output = query_p2l_endpoint(prompt, self.base_url, self.api_key)
coefs = np.array(p2l_output["coefs"])
return coefs
def route(self, messages: List[ChatMessage], cost: float = None) -> RouterOutput:
model_scores = self._get_model_scores(messages)
router_choice_scores = model_scores[self.mask]
router_opponent_scores = model_scores[self.q_mask]
chosen_model_name = self.cost_optimizer.select_model(
cost,
self.model_list,
self.model_costs,
router_choice_scores,
opponent_scores=router_opponent_scores,
opponent_distribution=self.q,
)
model_scores_dict = dict(zip(self.model_list, router_choice_scores))
chosen_model_config = self.model_config_container.get_model_config(
chosen_model_name
)
return RouterOutput(
chosen_model_name=chosen_model_name,
chosen_model_config=chosen_model_config,
model_scores=model_scores_dict,
)
@register("bag-endpoint")
@register("grk-endpoint")
class BagEndpointP2LRouter(BaseRouter):
def __init__(
self,
model_config_container: ModelConfigContainer,
cost_optimizer_type: str,
router_model_endpoint: str,
router_api_key: str,
**kwargs,
):
super().__init__(
model_config_container=model_config_container,
cost_optimizer_type=cost_optimizer_type,
)
self.base_url = router_model_endpoint
self.api_key = router_api_key
router_model_list = get_p2l_endpoint_models(self.base_url, self.api_key)
config_model_list = model_config_container.list_models()
self.mask = [
router_model in config_model_list for router_model in router_model_list
]
self.model_list = [
model for model, keep in zip(router_model_list, self.mask) if keep
]
self.model_costs = np.array(
[
model_config_container.get_model_config(model).get_cost()
for model in self.model_list
]
)
def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:
prompt = self._get_prompt(messages)
p2l_output = query_p2l_endpoint(prompt, self.base_url, self.api_key)
coefs = np.array(p2l_output["coefs"])
model_scores: np.ndarray[float] = expit(coefs)
return model_scores[self.mask]
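
For completeness, a new router can be plugged in through the same registry pattern; the `cheapest` name and `CheapestRouter` class below are hypothetical, not part of the repo.

```python
# Hypothetical example: a router that always favors the cheapest configured model,
# registered with the same decorator pattern used in route/routers.py.
from typing import List

import numpy as np

from route.datatypes import ChatMessage, ModelConfigContainer
from route.routers import BaseRouter, register


@register("cheapest")
class CheapestRouter(BaseRouter):
    def __init__(
        self,
        model_config_container: ModelConfigContainer,
        cost_optimizer_type: str,
        **kwargs,
    ):
        super().__init__(
            model_config_container=model_config_container,
            cost_optimizer_type=cost_optimizer_type,
        )
        self.model_list = model_config_container.list_models()
        self.model_costs = np.array(model_config_container.list_costs())

    def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray[float]:
        # Higher score = cheaper model, so the cost optimizer's argmax picks it.
        return -self.model_costs
```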
================================================
FILE: route/utils.py
================================================
from typing import Dict, Callable, List
import requests
import json
def get_registry_decorator(registry: Dict) -> Callable:
def register(name: str):
def decorator(cls: Callable):
            assert (
                name not in registry
            ), f"No duplicate registry names. '{name}' was registered more than once."
registry[name] = cls
return cls
return decorator
return register
def query_p2l_endpoint(
prompt: list[str], base_url: str, api_key: str
) -> Dict[str, List]:
headers = {
"Content-Type": "application/json",
"api-key": api_key,
}
payload = {"prompt": prompt}
try:
response = requests.post(
f"{base_url}/predict", headers=headers, data=json.dumps(payload)
)
response.raise_for_status()
result = response.json()
return result
except Exception as err:
raise err
def get_p2l_endpoint_models(base_url: str, api_key: str) -> List[str]:
headers = {
"Content-Type": "application/json",
"api-key": api_key,
}
try:
response = requests.get(f"{base_url}/models", headers=headers)
response.raise_for_status()
result = response.json()
return result["models"]
except Exception as err:
print(f"An error occurred: {err}")
================================================
FILE: serve_requirements.txt
================================================
numpy<2.0.0
torch<=2.4.0
transformers
transformers[torch]
hf_transfer
wandb
scipy
uvicorn
fastapi
================================================
FILE: train_requirements.txt
================================================
numpy<2.0.0
torch<=2.4.0
deepspeed<=0.15.3
datasets>=3.2.0
transformers
transformers[torch]
hf_transfer
wandb
scipy
================================================
FILE: training_configs/Llama3.1-8B-full-train.yaml
================================================
proj_name: Llama-3.1-8B-Instruct-full-train
learning_rate: 4.0e-6
adam_epsilon: 1.0e-8
batch_size: 4
max_length: 8192
num_train_epochs: 1
train_data_path: full-p2l-data
val_data_path: p2el/canonical_bt_val_data_11092024
output_dir: 'training_outputs'
pretrain_model_name: meta-llama/Llama-3.1-8B-Instruct
gradient_accumulation_steps: 16 # drop to 32 since 8 gpus
chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
model_type: "llama"
head_type: "bt"
loss_type: "bt_tie"
weighted_loss: false
deepspeed_config_path: deepspeed/zero1.json
init_type: reset_params
load_train_data_from_disk: true
pad_token_if_none: <|finetune_right_pad_id|>
cls_token_if_none: <|reserved_special_token_3|>
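Reading this config: with per-device batch_size 4, gradient_accumulation_steps 16, and the 8 GPUs implied by the inline comment, the effective global batch size comes out to 4 × 16 × 8 = 512 examples per optimizer step. A small sketch of loading the YAML and computing that figure (the 8-GPU count is an assumption taken from the comment, not a value stated in the config):

import yaml

with open("training_configs/Llama3.1-8B-full-train.yaml") as f:
    cfg = yaml.safe_load(f)

num_gpus = 8  # assumed from the "# drop to 32 since 8 gpus" comment above
effective_batch = cfg["batch_size"] * cfg["gradient_accumulation_steps"] * num_gpus
print(effective_batch)  # 4 * 16 * 8 = 512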
================================================
FILE: training_configs/Qwen2.5-1.5B-bag-chrono-eps-0.016-04302025.yaml
================================================
proj_name: Qwen2.5-1.5B-bag-chrono-eps-0.016-04302025
learning_rate: 8.0e-6
adam_epsilon: 1.0e-8
batch_size: 5
max_length: 16384
num_train_epochs: 1
train_data_path: naive_replay_buffer_eps_0.016
val_data_path: p2el/canonical_bt_val_data_11092024
output_dir: 'training_outputs'
pretrain_model_name: Qwen/Qwen2.5-1.5B-Instruct
gradient_accumulation_steps: 13 # drop to 32 since 8 gpus
chat_template: "{%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start
================================================
SYMBOL INDEX (260 symbols across 14 files)
================================================
FILE: p2l/auto_eval_utils.py
function register_simple_metric (line 26) | def register_simple_metric(loss_type: str, metric: str):
function register_aggr_metric (line 36) | def register_aggr_metric(loss_type: str, metric: str):
function register_helper (line 46) | def register_helper(loss_or_model_type: str, helper_func):
function output_labels_p2l (line 57) | def output_labels_p2l(val_data: pd.DataFrame, **kwargs):
function translate_coefs (line 68) | def translate_coefs(coef, old_list, new_list):
function output_labels_marginal (line 79) | def output_labels_marginal(
function output_labels_marginal_gt (line 102) | def output_labels_marginal_gt(
function output_labels_arena (line 115) | def output_labels_arena(
function preprocess_data_bag (line 133) | def preprocess_data_bag(data: pd.DataFrame, **kwargs):
function preprocess_data (line 145) | def preprocess_data(data: pd.DataFrame, **kwargs):
function loss (line 155) | def loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: str,...
function tie_loss (line 162) | def tie_loss(head_output: HeadOutputs, labels: torch.Tensor, loss_type: ...
function tie_bb_loss (line 168) | def tie_bb_loss(
function Aggr_Tie_Loss (line 180) | def Aggr_Tie_Loss(
function BCE_loss (line 195) | def BCE_loss(head_output: HeadOutputs, labels: torch.Tensor, **kwargs):
function aggr_metric (line 206) | def aggr_metric(metric_name, loss_type, labels, gt_output, model_output):
function Aggr_Loss (line 226) | def Aggr_Loss(
function Aggr_BCE_Loss (line 242) | def Aggr_BCE_Loss(
function expand_output (line 253) | def expand_output(output, labels):
function BT_mse (line 263) | def BT_mse(
function BT_tie_mst (line 280) | def BT_tie_mst(
function RK_mse (line 304) | def RK_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: st...
function bag_mse (line 319) | def bag_mse(head_output: HeadOutputs, labels: torch.Tensor, loss_type: s...
function rk_reparam_probs (line 335) | def rk_reparam_probs(
function bag_probs (line 360) | def bag_probs(
function rk_probs (line 387) | def rk_probs(
function BT_accuracy (line 406) | def BT_accuracy(
function BT_tie_accuracy (line 421) | def BT_tie_accuracy(
function RK_accuracy (line 441) | def RK_accuracy(
function RK_tie_accuracy (line 460) | def RK_tie_accuracy(
function bag_tie_accuracy (line 476) | def bag_tie_accuracy(
function bag_tie_bb_accuracy (line 493) | def bag_tie_bb_accuracy(
function Aggr_Tie_accuracy (line 513) | def Aggr_Tie_accuracy(
function Aggr_Tie_accuracy (line 529) | def Aggr_Tie_accuracy(
function Aggr_Tie_bb_accuracy (line 545) | def Aggr_Tie_bb_accuracy(
function Aggr_Tie_bb_loss (line 561) | def Aggr_Tie_bb_loss(
function not_implemented (line 586) | def not_implemented(
function bag_accuracy (line 593) | def bag_accuracy(
function beta_mean (line 616) | def beta_mean(
function beta_std (line 630) | def beta_std(
function beta_spread (line 644) | def beta_spread(
function beta_mean_spread (line 658) | def beta_mean_spread(
function beta_mean_iqr (line 674) | def beta_mean_iqr(
function beta_mean_std (line 690) | def beta_mean_std(
function aggr_marginal_gt (line 700) | def aggr_marginal_gt(
function aggr_p2l (line 708) | def aggr_p2l(
function aggr_p2l_batch (line 722) | def aggr_p2l_batch(
function aggr_p2l_batch (line 741) | def aggr_p2l_batch(
function aggr_non_p2l (line 755) | def aggr_non_p2l(head_output: HeadOutputs, loss_type: str, **kwargs):
function aggr_non_p2l (line 762) | def aggr_non_p2l(
function train_marginal (line 773) | def train_marginal(model_list, labels, loss_type, lr=1.0, tol=1e-9, max_...
function train_aggr_prob (line 808) | def train_aggr_prob(
function rk_eta (line 864) | def rk_eta(output):
function pairwise_RK_probs (line 874) | def pairwise_RK_probs(real_output: HeadOutputs):
function pairwise_RK_reparam_probs (line 900) | def pairwise_RK_reparam_probs(real_output: HeadOutputs, **kwargs):
function pairwise_bag_probs (line 926) | def pairwise_bag_probs(real_output: HeadOutputs, **kwargs):
function pairwise_BT_probs (line 958) | def pairwise_BT_probs(real_output: HeadOutputs):
function remove_beta_nan (line 976) | def remove_beta_nan(beta1, beta2):
function leaderboard (line 986) | def leaderboard(
function get_leaderboard (line 995) | def get_leaderboard(output, model_list):
function l1_dist_prob_bt (line 1017) | def l1_dist_prob_bt(gt_output: HeadOutputs, model_output: HeadOutputs, *...
function l1_dist_prob_rk (line 1035) | def l1_dist_prob_rk(
function l1_dist_prob_bag (line 1056) | def l1_dist_prob_bag(
function beta_iqr (line 1082) | def beta_iqr(gt_output: HeadOutputs, model_output: HeadOutputs, **kwargs):
function beta_std_aggr (line 1103) | def beta_std_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, **k...
function beta_spread_aggr (line 1118) | def beta_spread_aggr(gt_output: HeadOutputs, model_output: HeadOutputs, ...
function kendall_lb (line 1135) | def kendall_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwar...
function spearman_lb (line 1148) | def spearman_lb(gt_output: HeadOutputs, model_output: HeadOutputs, **kwa...
function top_k_frac (line 1156) | def top_k_frac(gt_betas: torch.tensor, model_betas: torch.tensor, k: int):
function top_k_displace (line 1164) | def top_k_displace(gt_betas: torch.tensor, model_betas: torch.tensor, k:...
function top_k_frac_dict (line 1177) | def top_k_frac_dict(gt_output: HeadOutputs, model_output: HeadOutputs, *...
function top_k_dist_dict (line 1192) | def top_k_dist_dict(gt_output: HeadOutputs, model_output: HeadOutputs, *...
FILE: p2l/auto_evals.py
function parse_model_list (line 25) | def parse_model_list(hf_model, local_path):
function change_beta_model_list (line 40) | def change_beta_model_list(df, old_list, new_list):
function parse_eval_output_data (line 49) | def parse_eval_output_data(
function add_labels_to_data (line 116) | def add_labels_to_data(data, loss_type, model_list):
function get_model_list_from_df (line 143) | def get_model_list_from_df(df):
function parse_train_data (line 147) | def parse_train_data(hf_data, local_path, loss_type, train_model_list):
function parse_arena_data (line 166) | def parse_arena_data(path, initial_rating=1000, BASE=10, SCALE=400):
function filter_battle_data (line 193) | def filter_battle_data(battles, category):
function get_arena_rankings (line 240) | def get_arena_rankings(data, category):
function get_subset_prompts (line 270) | def get_subset_prompts(output, labels, size):
function get_subset_prompts_batch (line 284) | def get_subset_prompts_batch(output, labels, size, batch_size):
function get_ith_output (line 299) | def get_ith_output(output, i):
function save_output (line 305) | def save_output(results, local_dir, hf_dir, file_name):
function simple_metrics (line 328) | def simple_metrics(metrics, output, labels, loss_type):
function category_metrics (line 343) | def category_metrics(
function random_subset_metrics (line 386) | def random_subset_metrics(
function aggr_scale_metrics (line 445) | def aggr_scale_metrics(
function get_metrics (line 529) | def get_metrics(
FILE: p2l/dataset.py
function get_model_list (line 7) | def get_model_list(dataset: Dataset):
function get_dataset (line 27) | def get_dataset(path: str, split: str, from_disk=False):
function _translate_label (line 40) | def _translate_label(
function translate_val_data (line 51) | def translate_val_data(
class DataCollator (line 71) | class DataCollator:
method __init__ (line 72) | def __init__(self, tokenizer, max_length, weight=None, reweight_scale=...
method __call__ (line 79) | def __call__(self, data):
FILE: p2l/endpoint.py
function parse_args (line 24) | def parse_args():
function lifespan (line 103) | async def lifespan(app: FastAPI):
class InputData (line 142) | class InputData(BaseModel):
class OutputData (line 146) | class OutputData(BaseModel):
class ModelList (line 151) | class ModelList(BaseModel):
class P2LPipeline (line 155) | class P2LPipeline(TextClassificationPipeline):
method preprocess (line 156) | def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Te...
method postprocess (line 181) | def postprocess(
function predict (line 195) | async def predict(input_data: InputData, api_key: str = Header(...)):
function models (line 222) | async def models(api_key: str = Header(...)):
function load_model (line 241) | def load_model(
FILE: p2l/eval.py
class P2LPipeline (line 16) | class P2LPipeline(TextClassificationPipeline):
method preprocess (line 17) | def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, torch.Te...
method postprocess (line 39) | def postprocess(
class ListDataset (line 57) | class ListDataset(Dataset):
method __init__ (line 58) | def __init__(self, original_list):
method __len__ (line 61) | def __len__(self):
method __getitem__ (line 64) | def __getitem__(self, i):
function main (line 68) | def main(args, local_file=None):
FILE: p2l/model.py
function register_loss (line 30) | def register_loss(name: str):
function register_head (line 38) | def register_head(name: str):
function register_init (line 46) | def register_init(name: str):
function register_aggr_model (line 54) | def register_aggr_model(name: str):
function register_pairwise_loss (line 62) | def register_pairwise_loss(name: str):
function register_init (line 70) | def register_init(name: str):
class HeadOutputs (line 79) | class HeadOutputs(ModelOutput):
class P2LOutputs (line 86) | class P2LOutputs(ModelOutput):
function BT_loss (line 95) | def BT_loss(
function BT_tie_loss (line 122) | def BT_tie_loss(
function RK_Loss (line 161) | def RK_Loss(
function RK_Reparam_Loss (line 202) | def RK_Reparam_Loss(
function BA_loss (line 246) | def BA_loss(
function GRK_loss (line 297) | def GRK_loss(
class BTHead (line 347) | class BTHead(nn.Module):
method __init__ (line 348) | def __init__(
method forward (line 364) | def forward(self, last_hidden_dim: torch.Tensor):
class RKHead (line 370) | class RKHead(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 408) | def forward(self, last_hidden_dim: torch.Tensor):
class BAHead (line 416) | class BAHead(nn.Module):
method __init__ (line 417) | def __init__(
method forward (line 433) | def forward(self, last_hidden_dim: torch.Tensor):
function reset_params_init (line 443) | def reset_params_init(module):
function he_unif_init (line 448) | def he_unif_init(module):
function xavier_unif_init (line 453) | def xavier_unif_init(module):
function tiny_normal_init (line 458) | def tiny_normal_init(module):
function get_p2l_model (line 462) | def get_p2l_model(
function get_tokenizer (line 567) | def get_tokenizer(
class BTAggrModel (line 590) | class BTAggrModel(nn.Module):
method __init__ (line 591) | def __init__(self, num_models, batch_size=1):
method forward (line 598) | def forward(self):
class RKAggrModel (line 606) | class RKAggrModel(nn.Module):
method __init__ (line 607) | def __init__(self, num_models, batch_size=1):
method forward (line 614) | def forward(self):
function pairwise_batch_BT_loss (line 620) | def pairwise_batch_BT_loss(
function pairwise_batch_RK_loss (line 654) | def pairwise_batch_RK_loss(
function pairwise_batch_RK_reparam_loss (line 695) | def pairwise_batch_RK_reparam_loss(
function get_bag_probs (line 733) | def get_bag_probs(beta_win, beta_lose, gamma, theta):
function pairwise_batch_bag_loss (line 752) | def pairwise_batch_bag_loss(
function RK_Tie_Loss (line 785) | def RK_Tie_Loss(
function bag_tie_loss (line 821) | def bag_tie_loss(
function bag_tie_bb_loss (line 859) | def bag_tie_bb_loss(
FILE: p2l/train.py
class NoShuffleTrainer (line 14) | class NoShuffleTrainer(Trainer):
method _get_train_sampler (line 15) | def _get_train_sampler(self) -> Optional[Sampler]:
function train_model (line 19) | def train_model(args):
FILE: probe_barrier.py
function log (line 7) | def log(msg: str, rank: int):
function worker (line 11) | def worker(rank: int, world_size: int, backend: str):
function main (line 40) | def main():
FILE: route/chat.py
class BaseChatHandler (line 28) | class BaseChatHandler(ABC):
method _create_client (line 32) | def _create_client(model_config: ModelConfig):
method _handle_system_prompt (line 37) | def _handle_system_prompt(
method generate (line 44) | def generate(
method generate_stream (line 55) | def generate_stream(
class OpenAIChatHandler (line 71) | class OpenAIChatHandler(BaseChatHandler):
method _create_client (line 74) | def _create_client(model_config: ModelConfig):
method _handle_system_prompt (line 93) | def _handle_system_prompt(
method _create_completion (line 111) | def _create_completion(
method generate (line 137) | def generate(
method _skip (line 189) | def _skip(chunk: ChatCompletionChunk) -> bool:
method generate_stream (line 200) | def generate_stream(
class OpenaiReasoningChatHandler (line 270) | class OpenaiReasoningChatHandler(OpenAIChatHandler):
method _create_completion (line 273) | def _create_completion(
class OpenaiO1ChatHandler (line 294) | class OpenaiO1ChatHandler(OpenaiReasoningChatHandler):
method generate_stream (line 297) | def generate_stream(
class AnthropicChatHandler (line 354) | class AnthropicChatHandler(BaseChatHandler):
method _create_client (line 357) | def _create_client(model_config: ModelConfig):
method _handle_system_prompt (line 363) | def _handle_system_prompt(
method generate (line 381) | def generate(
method generate_stream (line 431) | def generate_stream(
class GeminiChatHandler (line 511) | class GeminiChatHandler(BaseChatHandler):
method _create_client (line 521) | def _create_client(model_config: ModelConfig):
method _handle_system_prompt (line 530) | def _handle_system_prompt(
method _create_completion (line 548) | def _create_completion(
method generate (line 594) | def generate(
method generate_stream (line 644) | def generate_stream(
FILE: route/cost_optimizers.py
class UnfulfillableException (line 9) | class UnfulfillableException(Exception):
class BaseCostOptimizer (line 13) | class BaseCostOptimizer(ABC):
method __init__ (line 14) | def __init__(self):
method select_model (line 19) | def select_model(
method select_max_score_model (line 29) | def select_max_score_model(
class StrictCostOptimizer (line 44) | class StrictCostOptimizer(BaseCostOptimizer):
method __init__ (line 46) | def __init__(self):
method select_model (line 50) | def select_model(
class SimpleLPCostOptimizer (line 84) | class SimpleLPCostOptimizer(BaseCostOptimizer):
method __init__ (line 86) | def __init__(self):
method select_model (line 90) | def select_model(
class OptimalLPCostOptimizer (line 122) | class OptimalLPCostOptimizer(BaseCostOptimizer):
method __init__ (line 124) | def __init__(self):
method select_model (line 128) | def select_model(
method _construct_W (line 163) | def _construct_W(
FILE: route/datatypes.py
class ModelConfig (line 7) | class ModelConfig:
method __init__ (line 9) | def __init__(self, config: Dict[str, Any]):
method get_name (line 12) | def get_name(self) -> str:
method get_temp (line 15) | def get_temp(self) -> float:
method get_top_p (line 18) | def get_top_p(self) -> float:
method get_top_k (line 21) | def get_top_k(self, default=None) -> int:
method get_system_prompt (line 24) | def get_system_prompt(self, default=None) -> str | None | Any:
method get_api_key (line 27) | def get_api_key(self, default=None) -> str | None | Any:
method get_base_url (line 30) | def get_base_url(self, default=None) -> str | None | Any:
method get_type (line 33) | def get_type(self) -> str:
method get_cost (line 36) | def get_cost(self) -> float:
method get_max_tokens (line 39) | def get_max_tokens(self, default=None) -> int | None | Any:
method get_extra_fields (line 42) | def get_extra_fields(self) -> Dict:
method __repr__ (line 45) | def __repr__(self):
class ModelConfigContainer (line 55) | class ModelConfigContainer:
method __init__ (line 56) | def __init__(self, model_config_dicts: Dict[str, Dict[str, Any]]):
method get_model_config (line 61) | def get_model_config(self, model_name: str) -> ModelConfig:
method list_models (line 64) | def list_models(self) -> List[str]:
method list_costs (line 67) | def list_costs(self) -> List[float]:
method __repr__ (line 77) | def __repr__(self):
class Roles (line 81) | class Roles(Enum):
class ChatMessage (line 87) | class ChatMessage(BaseModel):
class ChatCompletionRequest (line 99) | class ChatCompletionRequest(BaseModel):
class Choice (line 116) | class Choice(BaseModel):
class ChatCompletionResponse (line 126) | class ChatCompletionResponse(BaseModel):
class ChatMessageDelta (line 140) | class ChatMessageDelta(BaseModel):
class ChoiceDelta (line 146) | class ChoiceDelta(BaseModel):
class ChatCompletionResponseChunk (line 152) | class ChatCompletionResponseChunk(BaseModel):
class RouterOutput (line 163) | class RouterOutput:
FILE: route/openai_server.py
function parse_args (line 24) | def parse_args():
function lifespan (line 46) | async def lifespan(app: FastAPI):
function create_chat_completion (line 94) | async def create_chat_completion(
function models (line 182) | async def models(authorization: str = Header(None)) -> List[str]:
FILE: route/routers.py
class BaseRouter (line 14) | class BaseRouter(ABC):
method __init__ (line 16) | def __init__(
method _get_model_scores (line 29) | def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray...
method _get_previous_response_model (line 32) | def _get_previous_response_model(self, messages: List[ChatMessage]) ->...
method _get_prompt (line 42) | def _get_prompt(self, messages: List[ChatMessage]) -> list[str]:
method get_model_direct (line 58) | def get_model_direct(self, model_name: str) -> RouterOutput:
method route (line 67) | def route(self, messages: List[ChatMessage], cost: float = None) -> Ro...
class RandomRouter (line 94) | class RandomRouter(BaseRouter):
method __init__ (line 97) | def __init__(
method _get_model_scores (line 111) | def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray...
class EndpointP2LRouter (line 116) | class EndpointP2LRouter(BaseRouter):
method __init__ (line 155) | def __init__(
method _get_model_scores (line 202) | def _get_model_scores(
method route (line 214) | def route(self, messages: List[ChatMessage], cost: float = None) -> Ro...
class EndpointP2LRouter (line 246) | class EndpointP2LRouter(BaseRouter):
method __init__ (line 247) | def __init__(
method _get_model_scores (line 281) | def _get_model_scores(self, messages: List[ChatMessage]) -> np.ndarray...
FILE: route/utils.py
function get_registry_decorator (line 6) | def get_registry_decorator(registry: Dict) -> Callable:
function query_p2l_endpoint (line 25) | def query_p2l_endpoint(
function get_p2l_endpoint_models (line 49) | def get_p2l_endpoint_models(base_url: str, api_key: str) -> List[str]: