, line 1)\",\n",
" 'stdlib': [],\n",
" 'non_stdlib': [],\n",
" 'relative_imports': 0})"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ok, info = check_python_modules(\"def a\")\n",
"ok, info"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qgFNXORy-lpO"
},
"outputs": [],
"source": [
"def function_works(completions, **kwargs):\n",
"    \"\"\"Reward completions whose extracted function can actually be created.\n",
"\n",
"    Returns one score per completion:\n",
"      -2.0 if no function was extracted or `check_python_modules` reported an error,\n",
"      -0.5 if `create_locked_down_function` raised,\n",
"       1.0 otherwise.\n",
"    \"\"\"\n",
"    scores = []\n",
"    for completion in completions:\n",
"        response = completion[0][\"content\"]\n",
"        function = extract_function(response)\n",
"        if function is None:\n",
"            scores.append(-2.0)\n",
"            continue\n",
"        _, info = check_python_modules(function)\n",
"        if \"error\" in info:\n",
"            scores.append(-2.0)\n",
"            continue\n",
"        try:\n",
"            create_locked_down_function(function)\n",
"        except KeyboardInterrupt:\n",
"            raise  # Don't swallow Ctrl-C: let the user stop training\n",
"        except BaseException:\n",
"            # Model-written code is untrusted and may raise any exception here\n",
"            scores.append(-0.5)\n",
"            continue\n",
"        scores.append(1.0)\n",
"    return scores"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gf69i2WT-m4K"
},
"source": [
"`no_cheating` checks whether the function cheated, for example by importing NumPy or other modules it should not use:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cUfHzCVx-nGK"
},
"outputs": [],
"source": [
"def no_cheating(completions, **kwargs):\n",
"    \"\"\"Heavily penalize strategies that fail `check_python_modules`.\n",
"\n",
"    This catches completions that cheat, e.g. by importing NumPy or other modules.\n",
"    Scores per completion: 1.0 if the check passes, -20.0 if it fails,\n",
"    and -1.0 when no function could be extracted at all.\n",
"    \"\"\"\n",
"    scores = []\n",
"    for completion in completions:\n",
"        candidate = extract_function(completion[0][\"content\"])\n",
"        if candidate is None:\n",
"            scores.append(-1.0)  # Failed creating function\n",
"            continue\n",
"        passed, _ = check_python_modules(candidate)\n",
"        scores.append(1.0 if passed else -20.0)  # Penalize heavily!\n",
"    return scores"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "slnqWG3FTror"
},
"source": [
"Next, `strategy_succeeds` checks whether the strategy actually allows the game to terminate. For example, a strategy that simply always returned \"W\" could get stuck, and would fail once it hits the time limit of 10 seconds.\n",
"\n",
"We also add a global `PRINTER` counter, which is used when printing out the strategy and board state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sNi129lYTpZ2"
},
"outputs": [],
"source": [
"import numpy as np\n",
"PRINTER = 0  # Call counter: only every 5th strategy is echoed before it is played\n",
"def strategy_succeeds(completions, **kwargs):\n",
"    \"\"\"Reward strategies by playing a game of 2048 with each one.\n",
"\n",
"    All completions in one call play a board built from the same random seed,\n",
"    so their scores are comparable. Scores per completion:\n",
"       0    if no function was extracted, `check_python_modules` reported an error,\n",
"            or `create_locked_down_function` raised;\n",
"      20.0  if the game state is \"success\";\n",
"       2.0  if the strategy ran but the game did not succeed;\n",
"      -1.0  on TimeoutError;\n",
"      -3.0  on any other exception while playing.\n",
"    \"\"\"\n",
"    global PRINTER\n",
"    scores = []\n",
"    # Generate a random game board with seed (shared by every completion in this call)\n",
"    seed = np.random.randint(10000)\n",
"    for completion in completions:\n",
"        response = completion[0][\"content\"]\n",
"        function = extract_function(response)\n",
"        printed = PRINTER % 5 == 0\n",
"        if printed:\n",
"            print(function)\n",
"        PRINTER += 1\n",
"        if function is None:\n",
"            scores.append(0)\n",
"            continue\n",
"        _, info = check_python_modules(function)\n",
"        if \"error\" in info:\n",
"            scores.append(0)\n",
"            continue\n",
"        try:\n",
"            new_strategy = create_locked_down_function(function)\n",
"        except KeyboardInterrupt:\n",
"            raise  # Don't swallow Ctrl-C: let the user stop training\n",
"        except BaseException:\n",
"            # Model-written code is untrusted and may raise any exception here\n",
"            scores.append(0)\n",
"            continue\n",
"        try:\n",
"            game = GameBoard(size = 6, seed = seed, target = 2048, probability_fours = 0.10)\n",
"            steps, game_state = execute_strategy(new_strategy, game)\n",
"            print(f\"Steps = {steps} State = {game_state}\")\n",
"            if not printed:\n",
"                print(function)  # Not echoed above, so show it next to its result\n",
"            print(game.board().pretty())\n",
"            if game_state == \"success\":\n",
"                scores.append(20.0)  # Success - massively reward!\n",
"            else:\n",
"                scores.append(2.0)  # Failed but function works!\n",
"        except TimeoutError:\n",
"            print(\"Timeout\")\n",
"            scores.append(-1.0)  # Failed with timeout\n",
"        except Exception as e:\n",
"            print(f\"Exception = {e}\")\n",
"            scores.append(-3.0)  # Failed\n",
"    return scores"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TCpSxtvSeAG_"
},
"source": [
"We'll now create the dataset, which repeats our prompt 1000 times. Remember to set the reasoning effort to low! You can choose high reasoning mode, but this will only work on GPUs with more memory, like H100s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ldf6SjLHVPRv",
"outputId": "589f7523-9835-49b5-c477-4e1d8b0744ff"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"181\n"
]
},
{
"data": {
"text/plain": [
"{'prompt': [{'content': 'Create a new short 2048 strategy using only native Python code.\\nYou are given a list of list of numbers for the current board state.\\nOutput one action for \"W\", \"A\", \"S\", \"D\" on what is the optimal next step.\\nOutput your new short function in backticks using the format below:\\n```python\\ndef strategy(board):\\n return \"W\" # Example\\n```\\nAll helper functions should be inside def strategy. Only output the short function `strategy`.',\n",
" 'role': 'user'}],\n",
" 'answer': 0,\n",
" 'reasoning_effort': 'low'}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import Dataset\n",
"# A single-turn chat prompt, repeated for every training row\n",
"prompt_messages = [{\"role\": \"user\", \"content\": prompt.strip()}]\n",
"dataset = Dataset.from_list(\n",
"    [{\"prompt\" : prompt_messages, \"answer\" : 0, \"reasoning_effort\": \"low\"}] * 1000\n",
")\n",
"# Token length of the templated prompt; used to size `max_prompt_length`\n",
"maximum_length = len(tokenizer.apply_chat_template(prompt_messages, add_generation_prompt = True))\n",
"print(maximum_length)\n",
"dataset[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9-IOMhVg-2AM"
},
"source": [
"\n",
"### Train the model\n",
"\n",
"Now set up the GRPO Trainer and all configurations! We also support GSPO, GAPO, Dr GRPO and more! Go to the Unsloth [Reinforcement Learning Docs](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide) for more options."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptqkXK2D4d6p",
"outputId": "2061b833-5b98-4a2b-e7f5-4bc4652d8300"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\n",
"We will change the batch size of 1 to the `num_generations` of 2\n"
]
}
],
"source": [
"max_prompt_length = maximum_length + 1 # + 1 just in case!\n",
"max_completion_length = max_seq_length - max_prompt_length\n",
"# Fail fast if the prompt leaves no room for the model's answer\n",
"assert max_completion_length > 0, f\"max_seq_length = {max_seq_length} is too small for a {max_prompt_length}-token prompt\"\n",
"\n",
"from trl import GRPOConfig, GRPOTrainer\n",
"training_args = GRPOConfig(\n",
"    temperature = 1.0,\n",
"    learning_rate = 5e-5,\n",
"    weight_decay = 0.01,\n",
"    warmup_ratio = 0.1,\n",
"    lr_scheduler_type = \"linear\",\n",
"    optim = \"adamw_8bit\",\n",
"    logging_steps = 1,\n",
"    # Unsloth expects this to be a multiple of num_generations and would\n",
"    # otherwise change 1 to 2 with a warning, so state the effective value\n",
"    per_device_train_batch_size = 2,\n",
"    gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n",
"    num_generations = 2, # Decrease if out of memory (keep the batch size a multiple of it)\n",
"    max_prompt_length = max_prompt_length,\n",
"    max_completion_length = max_completion_length,\n",
"    # num_train_epochs = 1, # Set to 1 for a full training run\n",
"    max_steps = 1000,\n",
"    save_steps = 100,\n",
"    report_to = \"none\", # Can use Weights & Biases, TrackIO\n",
"    output_dir = \"outputs\",\n",
"\n",
"    # For optional training + evaluation\n",
"    # fp16_full_eval = True,\n",
"    # per_device_eval_batch_size = 4,\n",
"    # eval_accumulation_steps = 1,\n",
"    # eval_strategy = \"steps\",\n",
"    # eval_steps = 1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r9Mv8UZO5hz-"
},
"source": [
"And let's run the trainer! Once training starts, you'll see a table of rewards. The goal is to see the `reward` column increase!\n",
"\n",
"You might have to wait 150 to 200 steps before seeing any improvement. You'll probably get 0 reward for the first 100 steps. Please be patient!\n",
"\n",
"| Step | Training Loss | reward | reward_std | completion_length | kl |\n",
"|------|---------------|-----------|------------|-------------------|----------|\n",
"| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n",
"| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n",
"| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vzOuSVCL_GA9",
"outputId": "349f907c-cc67-4890-e131-397694679634"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Unsloth: Switching to float32 training since model cannot work with float16\n"
]
}
],
"source": [
"# For optional training + evaluation\n",
"# new_dataset = dataset.train_test_split(test_size = 0.01)\n",
"\n",
"# The logged `reward` is the sum of the three reward functions below:\n",
"# does the strategy compile, does it avoid cheating, and does it play 2048 well.\n",
"trainer = GRPOTrainer(\n",
"    model = model,\n",
"    processing_class = tokenizer,\n",
"    args = training_args,\n",
"    train_dataset = dataset,\n",
"    reward_funcs = [function_works, no_cheating, strategy_succeeds],\n",
"\n",
"    # For optional training + evaluation (replace `train_dataset` above)\n",
"    # train_dataset = new_dataset[\"train\"],\n",
"    # eval_dataset = new_dataset[\"test\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fQhtuwP4cf34"
},
"source": [
"And let's train the model!\n",
"\n",
"**NOTE** On a free T4 GPU, one generation might sadly take 5 minutes since it's an older GPU - an A100 or H100 will be much faster!"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "VGRxPdSCcfC3",
"outputId": "f8bb720c-6d69-4f43-d9d1-a404842d2dff"
},
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 199998, 'pad_token_id': 200017}.\n",
"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 2\n",
" \\\\ /| Num examples = 1,000 | Num Epochs = 1 | Total steps = 1,000\n",
"O^O/ \\_/ \\ Batch size per device = 2 | Gradient accumulation steps = 1\n",
"\\ / Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2\n",
" \"-____-\" Trainable parameters = 1,990,656 of 20,916,747,840 (0.01% trained)\n",
"`generation_config` default values have been modified to match model-specific defaults: {'max_length': 131072}. If this is not desired, please set these values explicitly.\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"None\n",
"Steps = 1 State = failed\n",
"def strategy(board):\n",
" # simple heuristic: prefer right or down, then left, then up\n",
" for move in \"R D L U\".split():\n",
" pass\n",
"┌───┬───┬───┬───┬───┬───┐\n",
"│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"└───┴───┴───┴───┴───┴───┘\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [ 86/1000 8:06:01 < 88:08:29, 0.00 it/s, Epoch 0.09/1]\n",
"
\n",
" \n",
" \n",
" \n",
" | Step | \n",
" Training Loss | \n",
" reward | \n",
" reward_std | \n",
" completions / mean_length | \n",
" completions / min_length | \n",
" completions / max_length | \n",
" completions / clipped_ratio | \n",
" completions / mean_terminated_length | \n",
" completions / min_terminated_length | \n",
" completions / max_terminated_length | \n",
" kl | \n",
" rewards / function_works / mean | \n",
" rewards / function_works / std | \n",
" rewards / no_cheating / mean | \n",
" rewards / no_cheating / std | \n",
" rewards / strategy_succeeds / mean | \n",
" rewards / strategy_succeeds / std | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1 | \n",
" 0.000000 | \n",
" 0.500000 | \n",
" 4.949748 | \n",
" 329.000000 | \n",
" 72.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 72.000000 | \n",
" 72.000000 | \n",
" 72.000000 | \n",
" 0.002197 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 1.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.000000 | \n",
" 0.500000 | \n",
" 4.949748 | \n",
" 550.500000 | \n",
" 515.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 515.000000 | \n",
" 515.000000 | \n",
" 515.000000 | \n",
" 0.000298 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 1.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 538.000000 | \n",
" 490.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 490.000000 | \n",
" 490.000000 | \n",
" 490.000000 | \n",
" 0.000276 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.000000 | \n",
" 2.500000 | \n",
" 2.121320 | \n",
" 325.000000 | \n",
" 120.000000 | \n",
" 530.000000 | \n",
" 0.000000 | \n",
" 325.000000 | \n",
" 120.000000 | \n",
" 530.000000 | \n",
" 0.000568 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 0.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 5 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 437.000000 | \n",
" 288.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 288.000000 | \n",
" 288.000000 | \n",
" 288.000000 | \n",
" 0.001381 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 6 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
" 308.500000 | \n",
" 301.000000 | \n",
" 316.000000 | \n",
" 0.000000 | \n",
" 308.500000 | \n",
" 301.000000 | \n",
" 316.000000 | \n",
" 0.000826 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -3.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 7 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 519.000000 | \n",
" 452.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 452.000000 | \n",
" 452.000000 | \n",
" 452.000000 | \n",
" 0.000223 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 8 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 333.500000 | \n",
" 81.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 81.000000 | \n",
" 81.000000 | \n",
" 81.000000 | \n",
" 0.001181 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 9 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 568.500000 | \n",
" 551.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 551.000000 | \n",
" 551.000000 | \n",
" 551.000000 | \n",
" 0.000281 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 10 | \n",
" 0.000000 | \n",
" -3.000000 | \n",
" 0.000000 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000153 | \n",
" -2.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 11 | \n",
" 0.000000 | \n",
" 2.500000 | \n",
" 2.121320 | \n",
" 330.000000 | \n",
" 264.000000 | \n",
" 396.000000 | \n",
" 0.000000 | \n",
" 330.000000 | \n",
" 264.000000 | \n",
" 396.000000 | \n",
" 0.004015 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 0.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 12 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 374.500000 | \n",
" 360.000000 | \n",
" 389.000000 | \n",
" 0.000000 | \n",
" 374.500000 | \n",
" 360.000000 | \n",
" 389.000000 | \n",
" 0.000245 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 13 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 520.500000 | \n",
" 455.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 455.000000 | \n",
" 455.000000 | \n",
" 455.000000 | \n",
" 0.000915 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 14 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 406.500000 | \n",
" 227.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 227.000000 | \n",
" 227.000000 | \n",
" 227.000000 | \n",
" 0.007664 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 15 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 348.500000 | \n",
" 302.000000 | \n",
" 395.000000 | \n",
" 0.000000 | \n",
" 348.500000 | \n",
" 302.000000 | \n",
" 395.000000 | \n",
" 0.002411 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 16 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 408.000000 | \n",
" 379.000000 | \n",
" 437.000000 | \n",
" 0.000000 | \n",
" 408.000000 | \n",
" 379.000000 | \n",
" 437.000000 | \n",
" 0.002496 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 17 | \n",
" 0.000000 | \n",
" -12.500000 | \n",
" 13.435029 | \n",
" 493.000000 | \n",
" 400.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 400.000000 | \n",
" 400.000000 | \n",
" 400.000000 | \n",
" 0.009901 | \n",
" -2.000000 | \n",
" 0.000000 | \n",
" -10.500000 | \n",
" 13.435029 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 18 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 413.000000 | \n",
" 260.000000 | \n",
" 566.000000 | \n",
" 0.000000 | \n",
" 413.000000 | \n",
" 260.000000 | \n",
" 566.000000 | \n",
" 0.021275 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 19 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 487.500000 | \n",
" 389.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 389.000000 | \n",
" 389.000000 | \n",
" 389.000000 | \n",
" 0.019204 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 20 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 586.000000 | \n",
" 0.001022 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 21 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 397.500000 | \n",
" 276.000000 | \n",
" 519.000000 | \n",
" 0.000000 | \n",
" 397.500000 | \n",
" 276.000000 | \n",
" 519.000000 | \n",
" 0.027686 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 22 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 499.500000 | \n",
" 486.000000 | \n",
" 513.000000 | \n",
" 0.000000 | \n",
" 499.500000 | \n",
" 486.000000 | \n",
" 513.000000 | \n",
" 0.007218 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 23 | \n",
" 0.000000 | \n",
" -1.250000 | \n",
" 2.474874 | \n",
" 575.500000 | \n",
" 565.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 565.000000 | \n",
" 565.000000 | \n",
" 565.000000 | \n",
" 0.005928 | \n",
" -1.250000 | \n",
" 1.060660 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 24 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 563.500000 | \n",
" 541.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 541.000000 | \n",
" 541.000000 | \n",
" 541.000000 | \n",
" 0.008769 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 25 | \n",
" 0.000100 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 444.500000 | \n",
" 303.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 303.000000 | \n",
" 303.000000 | \n",
" 303.000000 | \n",
" 0.084963 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 26 | \n",
" 0.000100 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 419.000000 | \n",
" 252.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 252.000000 | \n",
" 252.000000 | \n",
" 252.000000 | \n",
" 0.114125 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 27 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 339.500000 | \n",
" 321.000000 | \n",
" 358.000000 | \n",
" 0.000000 | \n",
" 339.500000 | \n",
" 321.000000 | \n",
" 358.000000 | \n",
" 0.033457 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 28 | \n",
" 0.000100 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 372.500000 | \n",
" 311.000000 | \n",
" 434.000000 | \n",
" 0.000000 | \n",
" 372.500000 | \n",
" 311.000000 | \n",
" 434.000000 | \n",
" 0.081829 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 29 | \n",
" 0.000100 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 387.500000 | \n",
" 336.000000 | \n",
" 439.000000 | \n",
" 0.000000 | \n",
" 387.500000 | \n",
" 336.000000 | \n",
" 439.000000 | \n",
" 0.100017 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 30 | \n",
" 0.000100 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 463.000000 | \n",
" 410.000000 | \n",
" 516.000000 | \n",
" 0.000000 | \n",
" 463.000000 | \n",
" 410.000000 | \n",
" 516.000000 | \n",
" 0.095180 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 31 | \n",
" 0.000300 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 445.500000 | \n",
" 305.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 305.000000 | \n",
" 305.000000 | \n",
" 305.000000 | \n",
" 0.321803 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 32 | \n",
" 0.000300 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 425.000000 | \n",
" 310.000000 | \n",
" 540.000000 | \n",
" 0.000000 | \n",
" 425.000000 | \n",
" 310.000000 | \n",
" 540.000000 | \n",
" 0.335011 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 33 | \n",
" 0.000400 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 458.500000 | \n",
" 331.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 331.000000 | \n",
" 331.000000 | \n",
" 331.000000 | \n",
" 0.362238 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 34 | \n",
" 0.000500 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 347.500000 | \n",
" 207.000000 | \n",
" 488.000000 | \n",
" 0.000000 | \n",
" 347.500000 | \n",
" 207.000000 | \n",
" 488.000000 | \n",
" 0.518291 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 35 | \n",
" 0.000400 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 471.000000 | \n",
" 356.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 356.000000 | \n",
" 356.000000 | \n",
" 356.000000 | \n",
" 0.383606 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 36 | \n",
" 0.000700 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 393.000000 | \n",
" 200.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 200.000000 | \n",
" 200.000000 | \n",
" 200.000000 | \n",
" 0.674902 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 37 | \n",
" 0.000700 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 344.500000 | \n",
" 198.000000 | \n",
" 491.000000 | \n",
" 0.000000 | \n",
" 344.500000 | \n",
" 198.000000 | \n",
" 491.000000 | \n",
" 0.689294 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 38 | \n",
" 0.000600 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 473.500000 | \n",
" 361.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 361.000000 | \n",
" 361.000000 | \n",
" 361.000000 | \n",
" 0.607979 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 39 | \n",
" 0.000100 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 380.000000 | \n",
" 361.000000 | \n",
" 399.000000 | \n",
" 0.000000 | \n",
" 380.000000 | \n",
" 361.000000 | \n",
" 399.000000 | \n",
" 0.142165 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 40 | \n",
" 0.000300 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 386.500000 | \n",
" 352.000000 | \n",
" 421.000000 | \n",
" 0.000000 | \n",
" 386.500000 | \n",
" 352.000000 | \n",
" 421.000000 | \n",
" 0.293521 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 41 | \n",
" 0.000500 | \n",
" -10.500000 | \n",
" 16.263456 | \n",
" 107.500000 | \n",
" 89.000000 | \n",
" 126.000000 | \n",
" 0.000000 | \n",
" 107.500000 | \n",
" 89.000000 | \n",
" 126.000000 | \n",
" 0.465591 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" -9.500000 | \n",
" 14.849242 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 42 | \n",
" 0.000300 | \n",
" -0.250000 | \n",
" 1.060660 | \n",
" 410.000000 | \n",
" 373.000000 | \n",
" 447.000000 | \n",
" 0.000000 | \n",
" 410.000000 | \n",
" 373.000000 | \n",
" 447.000000 | \n",
" 0.314028 | \n",
" 0.250000 | \n",
" 1.060660 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 43 | \n",
" 0.000800 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 473.000000 | \n",
" 360.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 360.000000 | \n",
" 360.000000 | \n",
" 360.000000 | \n",
" 0.753577 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 44 | \n",
" 0.000400 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 528.500000 | \n",
" 471.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 471.000000 | \n",
" 471.000000 | \n",
" 471.000000 | \n",
" 0.370155 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 45 | \n",
" 0.000600 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 360.000000 | \n",
" 293.000000 | \n",
" 427.000000 | \n",
" 0.000000 | \n",
" 360.000000 | \n",
" 293.000000 | \n",
" 427.000000 | \n",
" 0.609444 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 46 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 581.500000 | \n",
" 577.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 577.000000 | \n",
" 577.000000 | \n",
" 577.000000 | \n",
" 0.021817 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 47 | \n",
" 0.000900 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 466.500000 | \n",
" 347.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 347.000000 | \n",
" 347.000000 | \n",
" 347.000000 | \n",
" 0.863071 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 48 | \n",
" 0.000700 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 495.000000 | \n",
" 404.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 404.000000 | \n",
" 404.000000 | \n",
" 404.000000 | \n",
" 0.727124 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 49 | \n",
" 0.000200 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 558.500000 | \n",
" 531.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 531.000000 | \n",
" 531.000000 | \n",
" 531.000000 | \n",
" 0.173142 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 50 | \n",
" 0.000100 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 477.000000 | \n",
" 465.000000 | \n",
" 489.000000 | \n",
" 0.000000 | \n",
" 477.000000 | \n",
" 465.000000 | \n",
" 489.000000 | \n",
" 0.089374 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 51 | \n",
" 0.001400 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 367.500000 | \n",
" 149.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 149.000000 | \n",
" 149.000000 | \n",
" 149.000000 | \n",
" 1.374907 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 52 | \n",
" 0.000900 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 458.500000 | \n",
" 331.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 331.000000 | \n",
" 331.000000 | \n",
" 331.000000 | \n",
" 0.929248 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 53 | \n",
" 0.000900 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 475.000000 | \n",
" 364.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 364.000000 | \n",
" 364.000000 | \n",
" 364.000000 | \n",
" 0.887930 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 54 | \n",
" 0.000100 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 439.000000 | \n",
" 424.000000 | \n",
" 454.000000 | \n",
" 0.000000 | \n",
" 439.000000 | \n",
" 424.000000 | \n",
" 454.000000 | \n",
" 0.126352 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 55 | \n",
" 0.000400 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 323.500000 | \n",
" 293.000000 | \n",
" 354.000000 | \n",
" 0.000000 | \n",
" 323.500000 | \n",
" 293.000000 | \n",
" 354.000000 | \n",
" 0.367167 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 56 | \n",
" 0.000400 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 543.000000 | \n",
" 500.000000 | \n",
" 586.000000 | \n",
" 0.000000 | \n",
" 543.000000 | \n",
" 500.000000 | \n",
" 586.000000 | \n",
" 0.375893 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 57 | \n",
" 0.000700 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 382.000000 | \n",
" 317.000000 | \n",
" 447.000000 | \n",
" 0.000000 | \n",
" 382.000000 | \n",
" 317.000000 | \n",
" 447.000000 | \n",
" 0.687571 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 58 | \n",
" 0.000600 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 506.000000 | \n",
" 426.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 426.000000 | \n",
" 426.000000 | \n",
" 426.000000 | \n",
" 0.648271 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 59 | \n",
" 0.001100 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 260.500000 | \n",
" 187.000000 | \n",
" 334.000000 | \n",
" 0.000000 | \n",
" 260.500000 | \n",
" 187.000000 | \n",
" 334.000000 | \n",
" 1.084255 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 60 | \n",
" 0.000200 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 523.500000 | \n",
" 495.000000 | \n",
" 552.000000 | \n",
" 0.000000 | \n",
" 523.500000 | \n",
" 495.000000 | \n",
" 552.000000 | \n",
" 0.198019 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 61 | \n",
" 0.001000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 471.500000 | \n",
" 357.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 357.000000 | \n",
" 357.000000 | \n",
" 357.000000 | \n",
" 0.987108 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 62 | \n",
" 0.000400 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 532.000000 | \n",
" 478.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 478.000000 | \n",
" 478.000000 | \n",
" 478.000000 | \n",
" 0.428900 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 63 | \n",
" 0.000100 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
" 411.000000 | \n",
" 400.000000 | \n",
" 422.000000 | \n",
" 0.000000 | \n",
" 411.000000 | \n",
" 400.000000 | \n",
" 422.000000 | \n",
" 0.107686 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -3.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 64 | \n",
" 0.001000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 470.500000 | \n",
" 355.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 355.000000 | \n",
" 355.000000 | \n",
" 355.000000 | \n",
" 0.967091 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 65 | \n",
" 0.000300 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 553.000000 | \n",
" 520.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 520.000000 | \n",
" 520.000000 | \n",
" 520.000000 | \n",
" 0.262037 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 66 | \n",
" 0.000400 | \n",
" 2.500000 | \n",
" 2.121320 | \n",
" 471.500000 | \n",
" 423.000000 | \n",
" 520.000000 | \n",
" 0.000000 | \n",
" 471.500000 | \n",
" 423.000000 | \n",
" 520.000000 | \n",
" 0.414690 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 0.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 67 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 580.500000 | \n",
" 575.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 575.000000 | \n",
" 575.000000 | \n",
" 575.000000 | \n",
" 0.035250 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 68 | \n",
" 0.001200 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 435.000000 | \n",
" 284.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 284.000000 | \n",
" 284.000000 | \n",
" 284.000000 | \n",
" 1.168353 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 69 | \n",
" 0.000800 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 492.000000 | \n",
" 398.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 398.000000 | \n",
" 398.000000 | \n",
" 398.000000 | \n",
" 0.789415 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 70 | \n",
" 0.000700 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 291.500000 | \n",
" 240.000000 | \n",
" 343.000000 | \n",
" 0.000000 | \n",
" 291.500000 | \n",
" 240.000000 | \n",
" 343.000000 | \n",
" 0.723002 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 71 | \n",
" 0.001000 | \n",
" -10.500000 | \n",
" 16.263456 | \n",
" 407.000000 | \n",
" 301.000000 | \n",
" 513.000000 | \n",
" 0.000000 | \n",
" 407.000000 | \n",
" 301.000000 | \n",
" 513.000000 | \n",
" 0.958203 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" -9.500000 | \n",
" 14.849242 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 72 | \n",
" 0.000900 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 362.500000 | \n",
" 279.000000 | \n",
" 446.000000 | \n",
" 0.000000 | \n",
" 362.500000 | \n",
" 279.000000 | \n",
" 446.000000 | \n",
" 0.902191 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 73 | \n",
" 0.000100 | \n",
" 0.750000 | \n",
" 0.353553 | \n",
" 479.000000 | \n",
" 466.000000 | \n",
" 492.000000 | \n",
" 0.000000 | \n",
" 479.000000 | \n",
" 466.000000 | \n",
" 492.000000 | \n",
" 0.102604 | \n",
" 0.250000 | \n",
" 1.060660 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 74 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
" 579.000000 | \n",
" 572.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 572.000000 | \n",
" 572.000000 | \n",
" 572.000000 | \n",
" 0.049443 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -1.500000 | \n",
" 2.121320 | \n",
"
\n",
" \n",
" | 75 | \n",
" 0.000200 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 530.500000 | \n",
" 507.000000 | \n",
" 554.000000 | \n",
" 0.000000 | \n",
" 530.500000 | \n",
" 507.000000 | \n",
" 554.000000 | \n",
" 0.173276 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 76 | \n",
" 0.000500 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 401.000000 | \n",
" 353.000000 | \n",
" 449.000000 | \n",
" 0.000000 | \n",
" 401.000000 | \n",
" 353.000000 | \n",
" 449.000000 | \n",
" 0.522857 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 77 | \n",
" 0.000300 | \n",
" 0.750000 | \n",
" 0.353553 | \n",
" 512.500000 | \n",
" 473.000000 | \n",
" 552.000000 | \n",
" 0.000000 | \n",
" 512.500000 | \n",
" 473.000000 | \n",
" 552.000000 | \n",
" 0.271977 | \n",
" 0.250000 | \n",
" 1.060660 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 78 | \n",
" 0.000200 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 432.500000 | \n",
" 411.000000 | \n",
" 454.000000 | \n",
" 0.000000 | \n",
" 432.500000 | \n",
" 411.000000 | \n",
" 454.000000 | \n",
" 0.181327 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 79 | \n",
" 0.000200 | \n",
" 10.500000 | \n",
" 16.263456 | \n",
" 475.000000 | \n",
" 452.000000 | \n",
" 498.000000 | \n",
" 0.000000 | \n",
" 475.000000 | \n",
" 452.000000 | \n",
" 498.000000 | \n",
" 0.200004 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 8.500000 | \n",
" 16.263456 | \n",
"
\n",
" \n",
" | 80 | \n",
" 0.000600 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" 341.000000 | \n",
" 296.000000 | \n",
" 386.000000 | \n",
" 0.000000 | \n",
" 341.000000 | \n",
" 296.000000 | \n",
" 386.000000 | \n",
" 0.606937 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -2.000000 | \n",
" 1.414214 | \n",
"
\n",
" \n",
" | 81 | \n",
" 0.000200 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 456.500000 | \n",
" 428.000000 | \n",
" 485.000000 | \n",
" 0.000000 | \n",
" 456.500000 | \n",
" 428.000000 | \n",
" 485.000000 | \n",
" 0.235978 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 82 | \n",
" 0.000800 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 407.000000 | \n",
" 326.000000 | \n",
" 488.000000 | \n",
" 0.000000 | \n",
" 407.000000 | \n",
" 326.000000 | \n",
" 488.000000 | \n",
" 0.825952 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" 1.000000 | \n",
" 0.000000 | \n",
" -1.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 83 | \n",
" 0.000200 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 557.500000 | \n",
" 529.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 529.000000 | \n",
" 529.000000 | \n",
" 529.000000 | \n",
" 0.239547 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
" | 84 | \n",
" 0.001600 | \n",
" -1.000000 | \n",
" 2.828427 | \n",
" 368.500000 | \n",
" 151.000000 | \n",
" 586.000000 | \n",
" 0.500000 | \n",
" 151.000000 | \n",
" 151.000000 | \n",
" 151.000000 | \n",
" 1.608883 | \n",
" -0.500000 | \n",
" 2.121320 | \n",
" 0.000000 | \n",
" 1.414214 | \n",
" -0.500000 | \n",
" 0.707107 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Steps = 1 State = failed\n",
"def strategy(board):\n",
" # Helper: simulate a move, return new board and score\n",
" def simulate(board, dir):\n",
" n = len(board)\n",
" new = [[0]*n for _ in range(n)]\n",
" score = 0\n",
" for i in range(n):\n",
" # extract line\n",
" if dir == 'A':\n",
" line = [board[i][j] for j in range(n)]\n",
" rev = False\n",
" elif dir == 'D':\n",
" line = [board[i][j] for j in range(n-1, -1, -1)]\n",
" rev = True\n",
" elif dir == 'W':\n",
" line = [board[j][i] for j in range(n)]\n",
" rev = False\n",
" else: # 'S'\n",
" line = [board[j][i] for j in range(n-1, -1, -1)]\n",
" rev = True\n",
" # compress and merge\n",
" new_line = [x for x in line if x != 0]\n",
" merged = []\n",
" j = 0\n",
" while j < len(new_line):\n",
" if j + 1 < len(new_line) and new_line[j] == new_line[j+1]:\n",
" merged.append(new_line[j]*2)\n",
" score += new_line[j]*2\n",
" j += 2\n",
" else:\n",
" merged.append(new_line[j])\n",
" j += 1\n",
" # fill with zeros\n",
" merged += [0]*(n-len(merged))\n",
" # place back\n",
" if rev:\n",
" merged = merged[::-1]\n",
" if dir in ('A','D'):\n",
" for j in range(n):\n",
" new[i][j] = merged[j]\n",
" else:\n",
" for j in range(n):\n",
" new[j][i] = merged[j]\n",
" return new, score\n",
"\n",
" best, best_dir = 0, None\n",
" for dir in ('W','A','S','D'):\n",
" _, score = simulate(board, dir)\n",
" if score > best:\n",
" best, best_dir = score, dir\n",
" return best_dir # returns one of 'W','A','S','D'\n",
"┌───┬───┬───┬───┬───┬───┐\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"└───┴───┴───┴───┴───┴───┘\n",
"Unsloth: Will smartly offload gradients to save VRAM!\n",
"def strategy(board):\n",
" # helpers\n",
" def move(b, d):\n",
" n = len(b)\n",
" def compress(row):\n",
" new = [x for x in row if x!=0]\n",
" for i in range(len(new)-1):\n",
" if new[i]==new[i+1]:\n",
" new[i]*=2; new[i+1]=0\n",
" return [x for x in new if x!=0]+[0]*(n-len(new))\n",
" res=[[0]*n for _ in range(n)]\n",
" if d==\"W\":\n",
" for j in range(n):\n",
" col=[b[i][j] for i in range(n)]\n",
" col=compress(col)\n",
" for i in range(n):\n",
" res[i][j]=col[i]\n",
" elif d==\"S\":\n",
" for j in range(n):\n",
" col=[b[i][j] for i in range(n)][::-1]\n",
" col=compress(col)\n",
" col=col[::-1]\n",
" for i in range(n):\n",
" res[i][j]=col[i]\n",
" elif d==\"A\":\n",
" for i in range(n):\n",
" row=compress(b[i])\n",
" res[i]=row\n",
" elif d==\"D\":\n",
" for i in range(n):\n",
" row=compress(b[i][::-1])\n",
" row=row[::-1]\n",
" res[i]=row\n",
" return res\n",
"\n",
" def score(b):\n",
" return sum(sum(row) for row in b)\n",
"\n",
" moves=\"WASD\"\n",
" best=None; best_val=-1\n",
" for m in moves:\n",
" nb=move(board, m)\n",
" val=score(nb)\n",
" if val>best_val and any(nb[i][j]!=board[i][j] for i in range(len(nb)) for j in range(len(nb[0]))):\n",
" best_val=val; best=m\n",
" return best if best else \"W\"\n",
"Exception = list index out of range\n",
"Timeout\n",
"Steps = 475 State = failed\n",
"def strategy(board):\n",
" def move_possible(board, direction):\n",
" rows, cols = len(board), len(board[0])\n",
" if direction == 'W':\n",
" for j in range(cols):\n",
" for i in range(1, rows):\n",
" if board[i][j] != 0:\n",
" for k in range(i-1, -1, -1):\n",
" if board[k][j] == 0 or board[k][j] == board[i][j]:\n",
" return True\n",
" if board[k][j] != 0:\n",
" break\n",
" elif direction == 'S':\n",
" for j in range(cols):\n",
" for i in range(rows-2, -1, -1):\n",
" if board[i][j] != 0:\n",
" for k in range(i+1, rows):\n",
" if board[k][j] == 0 or board[k][j] == board[i][j]:\n",
" return True\n",
" if board[k][j] != 0:\n",
" break\n",
" elif direction == 'A':\n",
" for i in range(rows):\n",
" for j in range(1, cols):\n",
" if board[i][j] != 0:\n",
" for k in range(j-1, -1, -1):\n",
" if board[i][k] == 0 or board[i][k] == board[i][j]:\n",
" return True\n",
" if board[i][k] != 0:\n",
" break\n",
" elif direction == 'D':\n",
" for i in range(rows):\n",
" for j in range(cols-2, -1, -1):\n",
" if board[i][j] != 0:\n",
" for k in range(j+1, cols):\n",
" if board[i][k] == 0 or board[i][k] == board[i][j]:\n",
" return True\n",
" if board[i][k] != 0:\n",
" break\n",
" return False\n",
"\n",
" # Prefer moves that allow a merge as they increase score\n",
" for d in ('W', 'S', 'A', 'D'):\n",
" if move_possible(board, d):\n",
" return d\n",
" # If no merges are possible, pick any direction that moves tiles\n",
" for d in ('W', 'S', 'A', 'D'):\n",
" if any(board[i][j] != 0 for i in range(len(board)) for j in range(len(board[0]))):\n",
" return d\n",
" return 'W'\n",
"┌───┬───┬───┬───┬───┬───┐\n",
"│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;154m128\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;226m256\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;154m128\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├───┼───┼───┼───┼───┼───┤\n",
"│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"└───┴───┴───┴───┴───┴───┘\n",
"Exception = '>' not supported between instances of 'tuple' and 'float'\n",
"def strategy(board):\n",
" import random, copy\n",
"\n",
" def rotate(b):\n",
" return [[b[3-j][i] for j in range(4)] for i in range(4)]\n",
"\n",
" def compress(b):\n",
" new = []\n",
" for row in b:\n",
" new_row = [x for x in row if x != 0]\n",
" new_row += [0]*(4-len(new_row))\n",
" new.append(new_row)\n",
" return new\n",
"\n",
" def merge(b):\n",
" for row in b:\n",
" for i in range(3):\n",
" if row[i]==row[i+1] and row[i]!=0:\n",
" row[i]*=2\n",
" row[i+1]=0\n",
"\n",
" def move(b, dir):\n",
" if dir==\"W\":\n",
" return merge(rotate(compress(rotate(b))))\n",
" if dir==\"S\":\n",
" return rotate(merge(compress(rotate(b))))\n",
" if dir==\"A\":\n",
" return merge(compress(b))\n",
" if dir==\"D\":\n",
" return rotate(merge(compress(rotate(b)))) # actually reverse\n",
"\n",
" best_score=0\n",
" best_move=None\n",
" for move_dir in \"WASD\":\n",
" new_board=move(copy.deepcopy(board), move_dir)\n",
" score=sum(sum(row) for row in new_board)\n",
" if score>best_score:\n",
" best_score=score\n",
" best_move=move_dir\n",
" return best_move\n",
"Exception = 'NoneType' object is not iterable\n",
"Exception = name 'n' is not defined\n",
"Timeout\n",
"Timeout\n",
"None\n",
"Timeout\n",
"def strategy(board):\n",
" # Prioritize merges, then favor left/up moves\n",
" rows, cols = len(board), len(board[0]) if board else 0\n",
"\n",
" # Helper to check if a move is possible\n",
" def can_move(direction):\n",
" if direction == 'W':\n",
" for c in range(cols):\n",
" for r in range(rows-1):\n",
" if board[r][c] == 0 or board[r][c] == board[r+1][c]:\n",
" return True\n",
" elif direction == 'A':\n",
" for r in range(rows):\n",
" for c in range(cols-1):\n",
" if board[r][c] == 0 or board[r][c] == board[r][c+1]:\n",
" return True\n",
" elif direction == 'S':\n",
" for c in range(cols):\n",
" for r in range(rows-1,0,-1):\n",
" if board[r][c] == 0 or board[r][c] == board[r-1][c]:\n",
" return True\n",
" elif direction == 'D':\n",
" for r in range(rows):\n",
" for c in range(cols-1,0,-1):\n",
" if board[r][c] == 0 or board[r][c] == board[r][c-1]:\n",
" return True\n",
" return False\n",
"\n",
" # Generate all moves\n",
" moves = []\n",
" for d in ['W', 'A', 'S', 'D']:\n",
" if can_move(d):\n",
" moves.append(d)\n",
"\n",
" # If multiple moves, pick one that maximizes the sum of merges\n",
" if not moves:\n",
" return 'W' # fallback\n",
" # Simple heuristic: prefer first move that allows a merge\n",
" return moves[0]\n",
"Timeout\n",
"Steps = 1512 State = failed\n",
"def strategy(board):\n",
" # helper to check possible merge in a row or column\n",
" def can_merge(lst):\n",
" for i in range(len(lst)-1):\n",
" if lst[i] > 0 and lst[i] == lst[i+1]:\n",
" return True\n",
" return False\n",
"\n",
" # try to move in a direction that creates a merge\n",
" for dir, delta in [(\"W\", (-1,0)), (\"A\", (0,-1)), (\"S\", (1,0)), (\"D\", (0,1))]:\n",
" merged = False\n",
" for i in range(len(board)):\n",
" for j in range(len(board[0])):\n",
" if board[i][j] > 0:\n",
" ni, nj = i + delta[0], j + delta[1]\n",
" if 0 <= ni < len(board) and 0 <= nj < len(board[0]):\n",
" if board[ni][nj] == 0:\n",
" return dir\n",
" if board[ni][nj] == board[i][j]:\n",
" return dir\n",
" # fallback: move down\n",
" return \"S\"\n",
"┌────┬────┬────┬────┬────┬────┐\n",
"│\u001b[38;5;214m 512\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;226m 256\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;154m 128\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;208m1024\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;226m 256\u001b[0m│\u001b[38;5;154m 128\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;226m 256\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;154m 128\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"└────┴────┴────┴────┴────┴────┘\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # Simple greedy: choose direction that keeps tiles sorted in ascending order left-bottom\n",
" best = \" \"\n",
" best_val = -1\n",
" for d in \"WASD\":\n",
" # simulate move\n",
" b = [row[:] for row in board]\n",
" # merge function\n",
" def merge(row):\n",
" new = [x for x in row if x != 0]\n",
" res = []\n",
" i = 0\n",
" while i < len(new):\n",
" if i+1 < len(new) and new[i] == new[i+1]:\n",
" res.append(new[i]*2)\n",
" i += 2\n",
" else:\n",
" res.append(new[i])\n",
" i += 1\n",
" return res + [0]*(len(row)-len(res))\n",
" moved = False\n",
" if d == \"W\":\n",
" for col in range(4):\n",
" col_vals = [board[r][col] for r in range(4)]\n",
" merged = merge(col_vals)\n",
" for r in range(4):\n",
" b[r][col] = merged[r]\n",
" elif d == \"S\":\n",
" for col in range(4):\n",
" col_vals = [board[r][col] for r in range(4)][::-1]\n",
" merged = merge(col_vals)[::-1]\n",
" for r in range(4):\n",
" b[r][col] = merged[r]\n",
" elif d == \"A\":\n",
" for r in range(4):\n",
" row_vals = board[r]\n",
" merged = merge(row_vals)\n",
" b[r] = merged\n",
" elif d == \"D\":\n",
" for r in range(4):\n",
" row_vals = board[r][::-1]\n",
" merged = merge(row_vals)[::-1]\n",
" b[r] = merged\n",
" score = sum(filter(None, [x for row in b for x in row]))\n",
" if score > best_val:\n",
" best_val = score\n",
" best = d\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Exception = 'str' object is not callable\n",
"Timeout\n",
"def strategy(board):\n",
" # helper to rotate board\n",
" def rotate(b): return [list(col)[::-1] for col in zip(*b)]\n",
" # helper to move up\n",
" def move_up(b):\n",
" n=len(b)\n",
" new=[[] for _ in range(n)]\n",
" for j in range(n):\n",
" col=[b[i][j] for i in range(n) if b[i][j]!=0]\n",
" merged=[]\n",
" i=0\n",
" while i< len(col):\n",
" if i+1best_val:\n",
" best_val=val; best=dir\n",
" return best\n",
"Exception = list assignment index out of range\n",
"Timeout\n",
"Exception = list index out of range\n",
"def strategy(board):\n",
" import copy\n",
" moves = \"WASD\"\n",
" best = None\n",
" best_score = -1\n",
" for m in moves:\n",
" b = copy.deepcopy(board)\n",
" if m==\"W\":\n",
" for c in range(len(b)):\n",
" merged = []\n",
" for r in range(len(b)):\n",
" val = b[r][c]\n",
" if val!=0:\n",
" merged.append(val)\n",
" i=0\n",
" while i+1best_score:\n",
" best_score=score; best=m\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Exception = unsupported operand type(s) for -: 'range' and 'int'\n",
"def strategy(board):\n",
" # board is a 4x4 list of ints, 0 for empty\n",
" # Simple greedy: move that merges most tiles\n",
" moves = {}\n",
" dirs = {\"W\": (-1,0), \"A\": (0,-1), \"S\": (1,0), \"D\": (0,1)}\n",
" for d, (dr,dc) in dirs.items():\n",
" # simulate move\n",
" new_board = [row[:] for row in board]\n",
" merged = 0\n",
" for i in range(4):\n",
" for j in range(4):\n",
" if new_board[i][j]==0: continue\n",
" ni, nj = i+dr, j+dc\n",
" while 0<=ni<4 and 0<=nj<4 and new_board[ni][nj]==0:\n",
" ni+=dr; nj+=dc\n",
" if 0<=ni<4 and 0<=nj<4 and new_board[ni][nj]==new_board[i][j]:\n",
" merged+=1\n",
" moves[d]=merged\n",
" # choose direction with most merges, default W\n",
" best = max(moves, key=moves.get)\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Exception = list index out of range\n",
"def strategy(board):\n",
" moves = \"WASD\"\n",
" best = None\n",
" best_score = -1\n",
" for m in moves:\n",
" new_board = [row[:] for row in board]\n",
" if m == \"W\":\n",
" new_board = _move_up(new_board)\n",
" elif m == \"A\":\n",
" new_board = _move_left(new_board)\n",
" elif m == \"S\":\n",
" new_board = _move_down(new_board)\n",
" else: # \"D\"\n",
" new_board = _move_right(new_board)\n",
" score = sum(sum(row) for row in new_board)\n",
" if score > best_score:\n",
" best_score, best = score, m\n",
" return best\n",
"\n",
"def _compress(line):\n",
" nonzero = [x for x in line if x]\n",
" res = []\n",
" i = 0\n",
" while i < len(nonzero):\n",
" if i + 1 < len(nonzero) and nonzero[i] == nonzero[i+1]:\n",
" res.append(nonzero[i]*2)\n",
" i += 2\n",
" else:\n",
" res.append(nonzero[i])\n",
" i += 1\n",
" return res + [0]*(len(line)-len(res))\n",
"\n",
"def _move_up(b):\n",
" n = len(b)\n",
" res = [[0]*n for _ in range(n)]\n",
" for j in range(n):\n",
" col = [b[i][j] for i in range(n)]\n",
" col = _compress(col)\n",
" for i in range(n):\n",
" res[i][j] = col[i]\n",
" return res\n",
"\n",
"def _move_down(b):\n",
" n = len(b)\n",
" res = [[0]*n for _ in range(n)]\n",
" for j in range(n):\n",
" col = [b[i][j] for i in range(n)][::-1]\n",
" col = _compress(col)\n",
" for i in range(n):\n",
" res[n-1-i][j] = col[i]\n",
" return res\n",
"\n",
"def _move_left(b):\n",
" n = len(b)\n",
" res = [[0]*n for _ in range(n)]\n",
" for i in range(n):\n",
" row = _compress(b[i])\n",
" res[i] = row\n",
" return res\n",
"\n",
"def _move_right(b):\n",
" n = len(b)\n",
" res = [[0]*n for _ in range(n)]\n",
" for i in range(n):\n",
" row = _compress(b[i][::-1])[::-1]\n",
" res[i] = row\n",
" return res\n",
"Exception = 'int' object is not subscriptable\n",
"Timeout\n",
"def strategy(board):\n",
" # helper to apply a move and return new board\n",
" def move(b, dir):\n",
" n = len(b)\n",
" res = [[0]*n for _ in range(n)]\n",
" for x in range(n):\n",
" line = []\n",
" for y in range(n):\n",
" i,j = (y,x) if dir==\"D\" else (x,y)\n",
" if dir==\"A\": i=j\n",
" # skip for brevity\n",
"\n",
" # simplified heuristic: choose direction that increases sum of merged tiles\n",
" best, best_sum = None, -1\n",
" dirs = \"WASD\"\n",
" for d in dirs:\n",
" new = move(board, d)\n",
" merged = sum(c for r in new for c in r) - sum(c for r in board for c in r)\n",
" if merged > best_sum:\n",
" best_sum, best = merged, d\n",
" return best\n",
"Exception = 'NoneType' object is not iterable\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" import math\n",
" def score(b):\n",
" empty = sum(1 for r in b for v in r if v==0)\n",
" mx = max(max(row) for row in b)\n",
" return empty*10 + mx\n",
" best=None; best_score=-math.inf\n",
" for move in \"WASD\":\n",
" new=board.copy()\n",
" # simulate simple move logic\n",
" if move==\"W\":\n",
" for col in range(4):\n",
" col_vals=[r[col] for r in new if r[col]!=0]\n",
" for i,row in enumerate(col_vals):\n",
" new[i][col]=col_vals[i]\n",
" for i in range(i+1,4):\n",
" new[i][col]=0\n",
" elif move==\"S\":\n",
" for col in range(4):\n",
" col_vals=[r[col] for r in new if r[col]!=0]\n",
" for i,row in enumerate(reversed(col_vals)):\n",
" new[3-i][col]=col_vals[i]\n",
" for i in range(3-i+1,4):\n",
" new[i][col]=0\n",
" elif move==\"A\":\n",
" for row in range(4):\n",
" row_vals=[v for v in new[row] if v!=0]\n",
" for i,v in enumerate(row_vals):\n",
" new[row][i]=row_vals[i]\n",
" for i in range(i+1,4):\n",
" new[row][i]=0\n",
" elif move==\"D\":\n",
" for row in range(4):\n",
" row_vals=[v for v in new[row] if v!=0]\n",
" for i,v in enumerate(reversed(row_vals)):\n",
" new[row][3-i]=row_vals[i]\n",
" for i in range(3-i+1,4):\n",
" new[row][i]=0\n",
" sc=score(new)\n",
" if sc>best_score:\n",
" best_score=sc; best=move\n",
" return best\n",
"Exception = cannot access local variable 'i' where it is not associated with a value\n",
"Timeout\n",
"Exception = name 'merge' is not defined\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # 4x4 board\n",
" moves = 'W A S D'.split()\n",
" best = None\n",
" best_score = -1\n",
" for m in moves:\n",
" b = [row[:] for row in board] # copy\n",
" for i in range(4):\n",
" line = b[i] if m in 'AD' else [row[i] for row in b]\n",
" merged = []\n",
" skip = False\n",
" for j, v in enumerate(line):\n",
" if v == 0: continue\n",
" if skip:\n",
" skip = False\n",
" continue\n",
" if j + 1 < len(line) and line[j+1] == v:\n",
" merged.append(v*2)\n",
" skip = True\n",
" else:\n",
" merged.append(v)\n",
" while len(merged) < 4:\n",
" merged.append(0)\n",
" if m in 'AD':\n",
" for k in range(4): b[i][k] = merged[k]\n",
" else:\n",
" for k in range(4): b[k][i] = merged[k]\n",
" score = sum(sum(row) for row in b)\n",
" if score > best_score:\n",
" best_score = score\n",
" best = m\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # board is a list of lists representing a 4x4 grid.\n",
" # possible moves\n",
" moves = ['W', 'A', 'S', 'D']\n",
" best = None\n",
" best_score = -1\n",
" \n",
" def score(b):\n",
" s = 0\n",
" for row in b:\n",
" for v in row:\n",
" s += v\n",
" return s\n",
" \n",
" for m in moves:\n",
" nb = [row[:] for row in board]\n",
" # simulate move m (very naive: just return new board if any merge)\n",
" merged = False\n",
" for i in range(4):\n",
" for j in range(4):\n",
" if nb[i][j] == 0: continue\n",
" for di, dj in ( (-1,0),(1,0),(0,-1),(0,1) ):\n",
" ni, nj = i+di, j+dj\n",
" if 0<=ni<4 and 0<=nj<4 and nb[ni][nj]==nb[i][j]:\n",
" nb[ni][nj] += nb[i][j]\n",
" nb[i][j] = 0\n",
" merged = True\n",
" if merged:\n",
" sc = score(nb)\n",
" if sc > best_score:\n",
" best_score, best = sc, m\n",
" return best if best is not None else moves[0]\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Exception = cannot access local variable 'val' where it is not associated with a value\n",
"None\n",
"Timeout\n",
"Timeout\n",
"Exception = not enough values to unpack (expected 2, got 1)\n",
"def strategy(board):\n",
" # evaluate a move by the total sum after the move\n",
" def sim(b, m):\n",
" n = len(b)\n",
" b = [row[:] for row in b]\n",
" moved = False\n",
" if m == 'W':\n",
" for j in range(n):\n",
" col = [b[i][j] for i in range(n)]\n",
" col += [0]*(n-len(col))\n",
" newcol = []\n",
" i = 0\n",
" while i < n:\n",
" if col[i] == 0:\n",
" i += 1\n",
" continue\n",
" val = col[i]\n",
" i += 1\n",
" while i < n and col[i] == 0: i += 1\n",
" if i < n and col[i] == val:\n",
" val *= 2\n",
" i += 1\n",
" newcol.append(val)\n",
" for i in range(n):\n",
" b[i][j] = newcol[i] if i < len(newcol) else 0\n",
" moved = True\n",
" # other moves omitted for brevity \n",
" return b if moved else None\n",
"\n",
" best, best_val = None, -1\n",
" for m in \"WASD\":\n",
" r = sim(board, m)\n",
" if r:\n",
" val = sum(sum(row) for row in r)\n",
" if val > best_val:\n",
" best_val, best = val, m\n",
" return best if best else \"W\"\n",
"Timeout\n",
"Exception = list index out of range\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
"Timeout\n",
"Exception = strategy..rotate() takes 1 positional argument but 2 were given\n",
"def strategy(board):\n",
" # helper to simulate a move\n",
" def move(b, direction):\n",
" size = len(b)\n",
" new = [[0]*size for _ in range(size)]\n",
" for i in range(size):\n",
" if direction in ('A','D'):\n",
" line = b[i] if direction=='D' else b[i][::-1]\n",
" else:\n",
" line = [b[j][i] for j in range(size)]\n",
" if direction=='S': line = line[::-1]\n",
" merged = []\n",
" skip = False\n",
" for val in line:\n",
" if val==0: continue\n",
" if merged and merged[-1]==val and not skip:\n",
" merged[-1] += val\n",
" skip = True\n",
" else:\n",
" merged.append(val)\n",
" skip = False\n",
" for j,v in enumerate(merged):\n",
" new[i if direction=='A' else size-1-i][j if direction=='A' else size-1-j] = v\n",
" return new\n",
"\n",
" # evaluate each move\n",
" best = None\n",
" best_val = -1\n",
" for dirc in 'WASD':\n",
" new_board = move(board, dirc)\n",
" val = sum(sum(row) for row in new_board)\n",
" if val > best_val:\n",
" best_val = val\n",
" best = dirc\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"None\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"None\n",
"Exception = unsupported operand type(s) for -: 'list' and 'int'\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # Simple heuristic: move up unless a merge is possible in another direction\n",
" # Check if any pair can merge horizontally or vertically\n",
" for i in range(4):\n",
" for j in range(3):\n",
" if board[i][j] == board[i][j+1]:\n",
" return \"A\" # left\n",
" for i in range(3):\n",
" for j in range(4):\n",
" if board[i][j] == board[i+1][j]:\n",
" return \"W\" # up\n",
" return \"D\" # fallback\n",
"Timeout\n",
"Exception = list index out of range\n",
"def strategy(board):\n",
" def score_for(move):\n",
" B = [row[:] for row in board]\n",
" def slide(row):\n",
" new = [x for x in row if x != 0]\n",
" res = []\n",
" skip = False\n",
" for i, x in enumerate(new):\n",
" if skip:\n",
" skip = False\n",
" continue\n",
" if i+1 < len(new) and new[i] == new[i+1]:\n",
" res.append(x*2)\n",
" skip = True\n",
" else:\n",
" res.append(x)\n",
" return res + [0]*(len(row)-len(res))\n",
" if move=='W':\n",
" for i in range(len(B)):\n",
" B[i] = slide(B[i])\n",
" elif move=='S':\n",
" B = B[::-1]\n",
" for i in range(len(B)):\n",
" B[i] = slide(B[i])\n",
" B = B[::-1]\n",
" elif move=='A':\n",
" for row in B:\n",
" row[:] = slide(row)\n",
" elif move=='D':\n",
" for row in B:\n",
" row[:] = slide(row[::-1])[::-1]\n",
" empty = sum(cell==0 for r in B for cell in r)\n",
" return empty\n",
" best=None\n",
" for m in 'WASD':\n",
" if score_for(m)>best[1] if best else -1:\n",
" best=(m,score_for(m))\n",
" return best[0]\n",
"Timeout\n",
"Timeout\n",
"Exception = list assignment index out of range\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" '''\n",
" Returns the best next move for a 2048 game using a very small heuristic.\n",
" The heuristic looks at the free spaces after the move and chooses the\n",
" direction that tends to leave the most empty tiles.\n",
" '''\n",
" from functools import lru_cache\n",
"\n",
" # Flatten the board for easier hashing\n",
" flatten = tuple(tuple(row) for row in board)\n",
"\n",
" # Helper: simulate a move\n",
" def move(state, direction):\n",
" size = len(state)\n",
" new_state = []\n",
" for row in state:\n",
" merged = []\n",
" for d in row:\n",
" if d != 0:\n",
" merged.append(d)\n",
"\n",
" if direction in ('A', 'D'): # horizontal move\n",
" merged = merged[::-1] if direction == 'D' else merged\n",
" i = 0\n",
" while i < len(merged) - 1:\n",
" if merged[i] == merged[i + 1]:\n",
" merged[i] *= 2\n",
" merged.pop(i + 1)\n",
" i += 1\n",
" merged += [0] * (size - len(merged))\n",
" if direction == 'D':\n",
" merged = merged[::-1]\n",
" new_state.append(tuple(merged))\n",
" else: # vertical move\n",
" new_state.append(tuple(merged))\n",
" # For vertical moves, reconstruct column-wise\n",
" if direction in ('W', 'S'):\n",
" transposed = list(zip(*new_state))\n",
" new_state = []\n",
" for col in transposed:\n",
" merged = []\n",
" for d in col:\n",
" if d != 0:\n",
" merged.append(d)\n",
" merged = merged[::-1] if direction == 'S' else merged\n",
" i = 0\n",
" while i < len(merged) - 1:\n",
" if merged[i] == merged[i + 1]:\n",
" merged[i] *= 2\n",
" merged.pop(i + 1)\n",
" i += 1\n",
" merged += [0] * (size - len(merged))\n",
" if direction == 'S':\n",
" merged = merged[::-1]\n",
" new_state.append(tuple(merged))\n",
" new_state = [tuple(row) for row in zip(*new_state)]\n",
" return tuple(tuple(row) for row in new_state)\n",
"\n",
" # Count empty tiles\n",
" def empty_count(state):\n",
" return sum(1 for row in state for cell in row if cell == 0)\n",
"\n",
" best_move = None\n",
" best_empty = -1\n",
" for move in ['W', 'A', 'S', 'D']:\n",
" new_board = move(flatten, move)\n",
" e = empty_count(new_board)\n",
" if e > best_empty:\n",
" best_empty = e\n",
" best_move = move\n",
" return best_move\n",
"Exception = 'str' object is not callable\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" import copy\n",
" # Helper to apply a move and return new board\n",
" def move(board, dir):\n",
" size = len(board)\n",
" def compress(line):\n",
" new = [x for x in line if x>0]\n",
" res = []\n",
" i = 0\n",
" while i < len(new):\n",
" if i+1 < len(new) and new[i]==new[i+1]:\n",
" res.append(new[i]*2)\n",
" i += 2\n",
" else:\n",
" res.append(new[i])\n",
" i += 1\n",
" res += [0]*(size-len(res))\n",
" return res\n",
" if dir=='W':\n",
" new = [compress(col) for col in zip(*board)]\n",
" return [list(row) for row in zip(*new)]\n",
" if dir=='A':\n",
" return [compress(row) for row in board]\n",
" if dir=='S':\n",
" rev = [list(reversed(row)) for row in board]\n",
" new = [compress(row) for row in rev]\n",
" return [list(reversed(row)) for row in new]\n",
" if dir=='D':\n",
" rev = [list(reversed(row)) for row in board]\n",
" new = [compress(row) for row in rev]\n",
" return [list(row) for row in new]\n",
" best = None\n",
" best_score = -1\n",
" for d in ['W','A','S','D']:\n",
" newboard = move(board, d)\n",
" # score: sum of all tiles (higher better)\n",
" score = sum(sum(row) for row in newboard)\n",
" if score > best_score:\n",
" best_score, best = score, d\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # helper to simulate a move and compute score\n",
" def simulate(move):\n",
" n = len(board)\n",
" new_board = [[0]*n for _ in range(n)]\n",
" for i in range(n):\n",
" line = board[i] if move in \"WB\" else [row[i] for row in board]\n",
" if move in \"DS\": # reverse for down/right\n",
" line = line[::-1]\n",
" merged = []\n",
" skip = False\n",
" for v in line:\n",
" if v == 0: continue\n",
" if merged and merged[-1][0] == v and not skip:\n",
" merged[-1] = (merged[-1][0]*2, merged[-1][1]+1)\n",
" skip = True\n",
" else:\n",
" merged.append((v, 0))\n",
" skip = False\n",
" merged += [(0,0)]*(n-len(merged))\n",
" for idx, (v, _) in enumerate(merged):\n",
" new_board[i if move in \"WD\" else idx][idx if move in \"WD\" else i] = v\n",
" return sum(sum(row) for row in new_board)\n",
"\n",
" best_move = None\n",
" best_score = -1\n",
" for m in \"WASD\":\n",
" try:\n",
" score = simulate(m)\n",
" if score > best_score:\n",
" best_score = score\n",
" best_move = m\n",
" except:\n",
" continue\n",
" return best_move or \"W\"\n",
"Timeout\n",
"Timeout\n",
"Exception = name 'n' is not defined\n",
"def strategy(board):\n",
" import copy\n",
" moves = {'W': (-1,0), 'A': (0,-1), 'S': (1,0), 'D': (0,1)}\n",
" def move(b, dir):\n",
" size = len(b)\n",
" mx, my = moves[dir]\n",
" new = [[0]*size for _ in range(size)]\n",
" for r in range(size):\n",
" line = []\n",
" nr = r + mx\n",
" for c in range(size):\n",
" nc = c + my\n",
" if 0 <= nr < size and 0 <= nc < size:\n",
" line.append(b[nr][nc])\n",
" # compress\n",
" res=[]\n",
" i=0\n",
" while i < len(line):\n",
" if i+10:\n",
" s+=b[r][c]\n",
" return s\n",
" best=None\n",
" best_score=-1\n",
" for m in moves:\n",
" nb=move(board,m)\n",
" s=score(nb)\n",
" if s>best_score:\n",
" best_score=s; best=m\n",
" return best\n",
"Exception = list index out of range\n",
"Exception = 'NoneType' object is not subscriptable\n",
"Exception = name 'col_index' is not defined\n",
"def strategy(board):\n",
" import copy\n",
" moves = \"WASD\"\n",
" best, best_move = -1, \"W\"\n",
" for m in moves:\n",
" b = copy.deepcopy(board)\n",
" if m == \"W\":\n",
" for i in range(3,-1,-1):\n",
" for j in range(4):\n",
" if b[i][j] and b[i-1][j] and b[i][j]==b[i-1][j]:\n",
" b[i-1][j]*=2; b[i][j]=0\n",
" elif m == \"S\":\n",
" for i in range(4):\n",
" for j in range(4):\n",
" if i<3 and b[i][j] and b[i+1][j] and b[i][j]==b[i+1][j]:\n",
" b[i+1][j]*=2; b[i][j]=0\n",
" elif m == \"A\":\n",
" for i in range(4):\n",
" for j in range(4):\n",
" if j<3 and b[i][j] and b[i][j+1] and b[i][j]==b[i][j+1]:\n",
" b[i][j+1]*=2; b[i][j]=0\n",
" elif m == \"D\":\n",
" for i in range(4):\n",
" for j in range(3,-1,-1):\n",
" if j>0 and b[i][j] and b[i][j-1] and b[i][j]==b[i][j-1]:\n",
" b[i][j-1]*=2; b[i][j]=0\n",
" score = sum(sum(row) for row in b)\n",
" if score > best:\n",
" best, best_move = score, m\n",
" return best_move\n",
"Timeout\n",
"Steps = 1825 State = failed\n",
"def strategy(board):\n",
" size = len(board)\n",
" # Helper to compute score of moves\n",
" def score_move(d):\n",
" new_board = [row[:] for row in board]\n",
" moved = False\n",
" if d == \"W\":\n",
" for j in range(size):\n",
" col = [new_board[i][j] for i in range(size)]\n",
" merged = merge(col)\n",
" for i in range(size):\n",
" new_board[i][j] = merged[i]\n",
" if merged != col:\n",
" moved = True\n",
" elif d == \"S\":\n",
" for j in range(size):\n",
" col = [new_board[i][j] for i in range(size)][::-1]\n",
" merged = merge(col)[::-1]\n",
" for i in range(size):\n",
" new_board[i][j] = merged[i]\n",
" if merged[::-1] != col:\n",
" moved = True\n",
" elif d == \"A\":\n",
" for i in range(size):\n",
" row = new_board[i][:]\n",
" merged = merge(row)\n",
" new_board[i] = merged\n",
" if merged != row:\n",
" moved = True\n",
" elif d == \"D\":\n",
" for i in range(size):\n",
" row = new_board[i][::-1]\n",
" merged = merge(row)[::-1]\n",
" new_board[i] = merged\n",
" if merged[::-1] != row:\n",
" moved = True\n",
" return moved, new_board\n",
"\n",
" def merge(line):\n",
" filtered = [x for x in line if x != 0]\n",
" merged = []\n",
" i = 0\n",
" while i < len(filtered):\n",
" if i+1 < len(filtered) and filtered[i] == filtered[i+1]:\n",
" merged.append(filtered[i]*2)\n",
" i += 2\n",
" else:\n",
" merged.append(filtered[i])\n",
" i += 1\n",
" merged += [0]*(size-len(merged))\n",
" return merged\n",
"\n",
" # Evaluate each direction\n",
" best_score = -1\n",
" best_dir = \"W\"\n",
" for d in \"WASD\":\n",
" moved, new_board = score_move(d)\n",
" if not moved:\n",
" continue\n",
" # simple heuristic: sum of all tiles\n",
" score = sum(sum(row) for row in new_board)\n",
" if score > best_score:\n",
" best_score = score\n",
" best_dir = d\n",
" return best_dir\n",
"┌────┬────┬────┬────┬────┬────┐\n",
"│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;208m1024\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;208m1024\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;214m 512\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;214m 512\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;154m 128\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;226m 256\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"└────┴────┴────┴────┴────┴────┘\n",
"Timeout\n",
"def strategy(board):\n",
" # Evaluate score for each move and pick the one with maximal tile value\n",
" dirs = {\"W\": (-1,0), \"A\": (0,-1), \"S\": (1,0), \"D\": (0,1)}\n",
" best = None\n",
" best_score = -1\n",
" for d, (dx, dy) in dirs.items():\n",
" new_board = [[0]*4 for _ in range(4)]\n",
" moved = False\n",
" for i in range(4):\n",
" for j in range(4):\n",
" ni, nj = i+dx, j+dy\n",
" if 0 <= ni < 4 and 0 <= nj < 4:\n",
" new_board[ni][nj] = board[i][j]\n",
" if new_board[ni][nj] != board[i][j]:\n",
" moved = True\n",
" if not moved:\n",
" continue\n",
" score = sum([sum(row) for row in new_board])\n",
" if score > best_score:\n",
" best_score = score\n",
" best = d\n",
" return best if best is not None else \"W\"\n",
"Timeout\n",
"Timeout\n",
"Exception = 'list_reverseiterator' object is not subscriptable\n",
"Timeout\n",
"def strategy(board):\n",
" def score_row(row, dir):\n",
" if dir == 'L':\n",
" row = row[::-1]\n",
" merged = []\n",
" skip = False\n",
" for val in row:\n",
" if val == 0: continue\n",
" if skip:\n",
" skip = False\n",
" continue\n",
" if merged and merged[-1] == val:\n",
" merged[-1] *= 2\n",
" skip = True\n",
" else:\n",
" merged.append(val)\n",
" merged += [0]*(len(row)-len(merged))\n",
" if dir == 'L':\n",
" merged = merged[::-1]\n",
" return merged\n",
"\n",
" def move(board, action):\n",
" new_board = [row[:] for row in board]\n",
" if action in 'L':\n",
" for r in new_board:\n",
" new_row = score_row(r, 'L')\n",
" for i, val in enumerate(new_row):\n",
" r[i] = val\n",
" elif action in 'R':\n",
" for r in new_board:\n",
" new_row = score_row(r, 'R')\n",
" for i, val in enumerate(new_row):\n",
" r[i] = val\n",
" elif action in 'U':\n",
" for c in range(4):\n",
" col = [new_board[r][c] for r in range(4)]\n",
" new_col = score_row(col, 'L')\n",
" for r in range(4):\n",
" new_board[r][c] = new_col[r]\n",
" elif action in 'D':\n",
" for c in range(4):\n",
" col = [new_board[r][c] for r in range(4)]\n",
" new_col = score_row(col, 'R')\n",
" for r in range(4):\n",
" new_board[r][c] = new_col[r]\n",
" return new_board\n",
"\n",
" def empty(board):\n",
" return [(r, c) for r in range(4) for c in range(4) if board[r][c] == 0]\n",
"\n",
" actions = 'WASD'\n",
" best = None\n",
" best_score = -1\n",
" for a in actions:\n",
" new = move(board, a)\n",
" empties = len(empty(new))\n",
" merged = sum(1 for r in new for val in r if val >0)\n",
" score = empties + merged\n",
" if score>best_score:\n",
" best_score = score\n",
" best = a\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" # choose a move that keeps more tiles unchanged\n",
" moves = ['W','A','S','D']\n",
" best = moves[0]; best_score = -1\n",
" for m in moves:\n",
" new = board_state_after(board, m)\n",
" if new == board:\n",
" continue\n",
" score = score_board(new)\n",
" if score > best_score:\n",
" best_score = score; best = m\n",
" return best\n",
"def board_state_after(board, move):\n",
" # simulate move on a copy of the board\n",
" from copy import deepcopy\n",
" b = deepcopy(board)\n",
" n = len(b)\n",
" # simple implementation of move logic\n",
" def compress(line):\n",
" new = [x for x in line if x!=0]\n",
" res = []\n",
" i=0\n",
" while i < len(new):\n",
" if i+1 best_score:\n",
" best_score = score\n",
" best = dir\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"def strategy(board):\n",
" def move(board, dir):\n",
" import copy\n",
" n=len(board)\n",
" new=[row[:] for row in board]\n",
" if dir=='W':\n",
" for j in range(n):\n",
" col=[new[i][j] for i in range(n)]\n",
" newcol=compress(col)\n",
" for i in range(n): new[i][j]=newcol[i]\n",
" elif dir=='S':\n",
" for j in range(n):\n",
" col=[new[i][j] for i in range(n)][::-1]\n",
" newcol=compress(col)[::-1]\n",
" for i in range(n): new[i][j]=newcol[i]\n",
" elif dir=='A':\n",
" for i in range(n):\n",
" new[i]=compress(new[i])\n",
" elif dir=='D':\n",
" for i in range(n):\n",
" new[i]=compress(new[i])[::-1][::-1]\n",
" return new\n",
" def compress(line):\n",
" filtered=[v for v in line if v>0]\n",
" res=[]\n",
" i=0\n",
" while ibest_score:\n",
" best_score=score; best=d\n",
" return best\n",
"Timeout\n",
"Steps = 1264 State = success\n",
"def strategy(board):\n",
" # board is a 4x4 list of lists\n",
" import random\n",
" \n",
" # Directions with priority: diagonal corners\n",
" dirs = ['W', 'A', 'S', 'D']\n",
" for d in dirs:\n",
" new_board = [row[:] for row in board]\n",
" if d == 'W':\n",
" for j in range(4):\n",
" merged = False\n",
" for i in range(1, 4):\n",
" if new_board[i][j] == new_board[i-1][j] and not merged:\n",
" new_board[i-1][j] += new_board[i][j]\n",
" new_board[i][j] = 0\n",
" merged = True\n",
" elif d == 'S':\n",
" for j in range(4):\n",
" merged = False\n",
" for i in range(2, -1, -1):\n",
" if new_board[i][j] == new_board[i+1][j] and not merged:\n",
" new_board[i+1][j] += new_board[i][j]\n",
" new_board[i][j] = 0\n",
" merged = True\n",
" elif d == 'A':\n",
" for i in range(4):\n",
" merged = False\n",
" for j in range(1, 4):\n",
" if new_board[i][j] == new_board[i][j-1] and not merged:\n",
" new_board[i][j-1] += new_board[i][j]\n",
" new_board[i][j] = 0\n",
" merged = True\n",
" elif d == 'D':\n",
" for i in range(4):\n",
" merged = False\n",
" for j in range(2, -1, -1):\n",
" if new_board[i][j] == new_board[i][j+1] and not merged:\n",
" new_board[i][j+1] += new_board[i][j]\n",
" new_board[i][j] = 0\n",
" merged = True\n",
" # measure score: number of non-zero tiles\n",
" score = sum(1 for r in new_board for v in r if v != 0)\n",
" # choose first direction that reduces empty tiles\n",
" if score > sum(1 for r in board for v in r if v != 0):\n",
" return d\n",
" return random.choice(dirs)\n",
"┌────┬────┬────┬────┬────┬────┐\n",
"│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;239m .\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;118m 64\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;49m 8\u001b[0m│\u001b[38;5;239m .\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;46m 32\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\u001b[38;5;226m 256\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\n",
"├────┼────┼────┼────┼────┼────┤\n",
"│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;47m 16\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;196m2048\u001b[0m│\u001b[38;5;51m 4\u001b[0m│\u001b[38;5;45m 2\u001b[0m│\n",
"└────┴────┴────┴────┴────┴────┘\n",
"Exception = '>' not supported between instances of 'int' and 'str'\n",
"Exception = cannot pickle 'generator' object\n",
"Timeout\n",
"def strategy(board):\n",
" def move(board, direction):\n",
" size = len(board)\n",
" def compress(line):\n",
" new = [x for x in line if x>0]\n",
" merged = []\n",
" i=0\n",
" while i < len(new):\n",
" if i+1 < len(new) and new[i]==new[i+1]:\n",
" merged.append(new[i]*2)\n",
" i+=2\n",
" else:\n",
" merged.append(new[i])\n",
" i+=1\n",
" return merged+[0]*(size-len(merged))\n",
" new_board=[[0]*size for _ in range(size)]\n",
" if direction=='W':\n",
" for j in range(size):\n",
" col=[board[i][j] for i in range(size)]\n",
" col=compress(col)\n",
" for i in range(size):\n",
" new_board[i][j]=col[i]\n",
" elif direction=='S':\n",
" for j in range(size):\n",
" col=[board[i][j] for i in range(size)][::-1]\n",
" col=compress(col)[::-1]\n",
" for i in range(size):\n",
" new_board[i][j]=col[i]\n",
" elif direction=='A':\n",
" for i in range(size):\n",
" row=compress(board[i])\n",
" new_board[i]=row\n",
" elif direction=='D':\n",
" for i in range(size):\n",
" row=compress(board[i][::-1])[::-1]\n",
" new_board[i]=row\n",
" return new_board\n",
"\n",
" def score(b):\n",
" return sum(sum(1 for x in row if x>0) for row in b)\n",
"\n",
" best=None\n",
" bestScore=-1\n",
" for d in \"WASD\":\n",
" nb=move(board,d)\n",
" s=score(nb)\n",
" if s>bestScore:\n",
" bestScore=s\n",
" best=d\n",
" return best\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"Timeout\n",
"None\n",
"Timeout\n",
"Exception = list assignment index out of range\n",
"Timeout\n",
"None\n",
"Timeout\n"
]
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tlaUdxC_VHpz"
},
"source": [
"\n",
"# Inference\n",
"Now let's try the model we just trained!"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8BZZHOKiF9Ct",
"outputId": "4989f8d9-d024-462e-c732-b7734676791a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\n",
"Knowledge cutoff: 2024-06\n",
"Current date: 2025-10-06\n",
"\n",
"Reasoning: low\n",
"\n",
"# Valid channels: analysis, commentary, final. Channel must be included for every message.\n",
"Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>Create a new short 2048 strategy using only native Python code.\n",
"You are given a list of list of numbers for the current board state.\n",
"Output one action for \"W\", \"A\", \"S\", \"D\" on what is the optimal next step.\n",
"Output your new short function in backticks using the format below:\n",
"```python\n",
"def strategy(board):\n",
" return \"W\" # Example\n",
"```\n",
"All helper functions should be inside def strategy. Only output the short function `strategy`.<|end|><|start|>assistant<|channel|>analysis<|message|>We need a short strategy that picks best move using basic heuristics. Use simple scoring: prefer moving to increase score: choose move that gives largest tile sum after simulation. Write quick function.<|end|><|start|>assistant<|channel|>final<|message|>```python\n",
"def strategy(board):\n",
" import itertools\n",
" dirs = [\"W\",\"A\",\"S\",\"D\"]\n",
" best_dir = None\n",
" best_val = -1\n",
" for d in dirs:\n",
" new_board = [row[:] for row in board]\n",
" moved = False\n",
" if d == \"W\":\n",
" for col in range(4):\n",
" temp = [new_board[r][col] for r in range(4) if new_board[r][col]]\n",
" merged = []\n",
" i=0\n",
" while i=0:\n",
" if i-1>=0 and temp[i]==temp[i-1]:\n",
" merged.insert(0,temp[i]*2); i-=2\n",
" else:\n",
" merged.insert(0,temp[i]); i-=1\n",
" for r in range(4):\n",
" new_board[r][col]=merged[r] if r=0:\n",
" if i-1>=0 and temp[i]==temp[i-1]:\n",
" merged.insert(0,temp[i]*2); i-=2\n",
" else:\n",
" merged.insert(0,temp[i]); i-=1\n",
" for c in range(4):\n",
" new_board[r][c]=merged[c] if cbest_val:\n",
" best_val, best_dir = val, d\n",
" return best_dir\n",
"```<|return|>\n"
]
}
],
"source": [
"text = tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": prompt}],\n",
" tokenize = False,\n",
" add_generation_prompt = True,\n",
" reasoning_effort = \"low\",\n",
")\n",
"\n",
"from transformers import TextStreamer\n",
"_ = model.generate(\n",
" **tokenizer(text, return_tensors = \"pt\").to(\"cuda\"),\n",
" temperature = 1.0,\n",
" max_new_tokens = 1024,\n",
" streamer = TextStreamer(tokenizer, skip_prompt = False),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-NUEmHFSYNTp"
},
"source": [
"\n",
"### Saving to float16 or `MXFP4`\n",
"\n",
"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `mxfp4` for MXFP4 (OpenAI's gpt-oss native precision). We also allow saving `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can get your personal access tokens at https://huggingface.co/settings/tokens. Each option in the cell below is disabled with `if False:`; change it to `if True:` to run it."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "NjXGTkp7YNtB"
},
"outputs": [],
"source": [
"# Merge and push to hub in mxfp4 4bit format\n",
"if False:\n",
" model.save_pretrained_merged(\"finetuned_model\", tokenizer, save_method = \"mxfp4\")\n",
"if False:\n",
" model.push_to_hub_merged(\"repo_id/repo_name\", tokenizer, token = \"hf...\", save_method = \"mxfp4\")\n",
"\n",
"# Merge and push to hub in 16bit\n",
"if False:\n",
" model.save_pretrained_merged(\"finetuned_model\", tokenizer, save_method = \"merged_16bit\")\n",
"if False: # Pushing to HF Hub\n",
" model.push_to_hub_merged(\"hf/gpt-oss-finetune\", tokenizer, save_method = \"merged_16bit\", token = \"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V15Yhj1V9lwG"
},
"source": [
"# And we're done!\n",
"Congratulations, you just learned how to do reinforcement learning with gpt-oss! This notebook covered some advanced topics. To learn more about gpt-oss and RL, see Unsloth's [Reinforcement Learning Guide with gpt-oss](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning)."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"02d120e49f2c4f95a6090b1d8d521767": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_dbf5ed93dac646ed979fa7a8c569dfe3",
"placeholder": "",
"style": "IPY_MODEL_4db5ee5b7b674abba75fbce264e6dfa3",
"value": " 165/165 [00:00<00:00, 17.9kB/s]"
}
},
"04d39c4dda9f4a1bb01b8d6320032372": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"06ab9eaa6f0f48c4b68cff1ca4b9f2fa": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"07f0420c4dfa477caccd7ae96551c2e4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ad75f887a140416abfca615b2fc3c385",
"max": 3996690997,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_dee02a37a6f44f168546ee0077dc20d1",
"value": 3996690997
}
},
"0ac4d8e674804ad6bdc5f2d62f2e0d33": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_7bfcd9acf29646db8b6123708d1ffe27",
"IPY_MODEL_5e88d6515f16475fb72d7c153422b591",
"IPY_MODEL_5e5b77dd649547f896ab306fccc94a4e"
],
"layout": "IPY_MODEL_a843fa23e6c94fb486bff8764574fdc5"
}
},
"0c0c96eeac664f339aa4511bf47087e2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_18451e19df5449b1853b5e13dacd19c5",
"IPY_MODEL_d864d29d02c54ecfaedd7b866a6df8c2",
"IPY_MODEL_7875163297284832a35aca84cbb105ce"
],
"layout": "IPY_MODEL_d42d8228ea1247a1a81bb99b18c4640c"
}
},
"0f99489932aa409b94ba34764aff19b0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"1183d3f2ad3c4fb0af1d925b5f9e3efe": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9cc51d8029eb4217bc37daa918649692",
"IPY_MODEL_41f13d2f023e405180689e03bc2c32a1",
"IPY_MODEL_247484c0bf5945bcb4627b48928366c8"
],
"layout": "IPY_MODEL_14c0f20a9ab341ee966fe77815099ff0"
}
},
"147743757c804b85af2ef194f5f84e6a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"14c0f20a9ab341ee966fe77815099ff0": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"152d7bf2a74f400db3d3ecaa719ef8d1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"18451e19df5449b1853b5e13dacd19c5": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_bcda4c9a48e943a6a0ef812fcd64a6db",
"placeholder": "",
"style": "IPY_MODEL_61e491b843c347b6b2a9948de7caf01d",
"value": "tokenizer_config.json: "
}
},
"1c96edb2f7c948b9968b1239982af942": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ee23056662ad4b719b65005d776e0e72",
"placeholder": "",
"style": "IPY_MODEL_87765ca0996b403dbe29deef48d548bf",
"value": " 4.00G/4.00G [01:42<00:00, 117MB/s]"
}
},
"219ca32ab51e4b4385b2c1026a78503a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6c2ccfe3363b40b58fc26ea164d4ead4",
"IPY_MODEL_07f0420c4dfa477caccd7ae96551c2e4",
"IPY_MODEL_1c96edb2f7c948b9968b1239982af942"
],
"layout": "IPY_MODEL_d93be4994f104b6e99d89a9e73cd6abd"
}
},
"245590db7d374515a428ff4abbd25588": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"247484c0bf5945bcb4627b48928366c8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_cef064f1c55f41bf957fc4623260fdb4",
"placeholder": "",
"style": "IPY_MODEL_37cbe8800af04a42a0355922969b6393",
"value": " 4/4 [01:00<00:00, 13.06s/it]"
}
},
"263b7dc0b3fd465fac89b9266b19d526": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_147743757c804b85af2ef194f5f84e6a",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_2820e352ab004e818949acc31eb3888d",
"value": 4
}
},
"2820e352ab004e818949acc31eb3888d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2a6aa92676c74509b58373ca604c5b3b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2a6f43b64d164636a2d9708f0190f21b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2c40c6b846924200b29616a590af1672": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_06ab9eaa6f0f48c4b68cff1ca4b9f2fa",
"placeholder": "",
"style": "IPY_MODEL_d98c2b1e979b4929891a8ee0c11f55df",
"value": "model.safetensors.index.json: "
}
},
"2fa84865e9f14c1491402ef81517b4bd": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"32d6af64f2464cfb965671f2692b4e15": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"34a9e38b0b454a69a067d1ddadec7626": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_9c4d6839934b4b13952a850d2084d498",
"placeholder": "",
"style": "IPY_MODEL_c6a1decbc0e7421db622033214913cb9",
"value": "Fetching 4 files: 100%"
}
},
"350f29f737534bfba4258bc31ec274a2": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"36676899a61f4be4b631f6271f6ecec9": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"37cbe8800af04a42a0355922969b6393": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"3f9b801b52da4eb79f730d87bea5c338": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_b66c6ded549d4db8a2e5ea8e5016615c",
"IPY_MODEL_43da5073c3ad4e98a3ade17a0bb3b93d",
"IPY_MODEL_40365e2c9fef49148e4c93592d458afc"
],
"layout": "IPY_MODEL_7e9d5212fc7844f286e14b70cbf0bc7a"
}
},
"40138ff29073407abb95f793509fc320": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"40365e2c9fef49148e4c93592d458afc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2a6f43b64d164636a2d9708f0190f21b",
"placeholder": "",
"style": "IPY_MODEL_65c62d2198e64ee4a9e6547c2733135a",
"value": " 1.16G/1.16G [00:25<00:00, 39.8MB/s]"
}
},
"41f13d2f023e405180689e03bc2c32a1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_36676899a61f4be4b631f6271f6ecec9",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_77ecad9f150c430fa85f5833d97c42df",
"value": 4
}
},
"43da5073c3ad4e98a3ade17a0bb3b93d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4513a73fa95b41b5b6edadc9143ba9c1",
"max": 1158267008,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_792d75a7d18945e7972826ac5b2ac386",
"value": 1158267008
}
},
"4513a73fa95b41b5b6edadc9143ba9c1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"48741bbdeccb459aa4eea9c61339764b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4b9b3fe8dc764eedb9e18f166fe2f548": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_87a808c4d4f54f719adcd29de7206e1b",
"placeholder": "",
"style": "IPY_MODEL_5f0b2a0e1953406b88af2c884904e2da",
"value": "model-00003-of-00004.safetensors: 100%"
}
},
"4cb119127b404f46a53012c62d004e28": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4d67b10ec7794170addb4e968e20f170": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4da21f53bf7f4e2d8132eb43e6ecc739": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4db5ee5b7b674abba75fbce264e6dfa3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4fbc4cfe529d471ba85f3ae8e53b28d6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_a0d0fedc5bec4f5b943fddf9a954fbdf",
"IPY_MODEL_cab602573c6940919f93e59fe6f4838d",
"IPY_MODEL_51b8f4ce40f94ac39cf44d98f1522ec7"
],
"layout": "IPY_MODEL_32d6af64f2464cfb965671f2692b4e15"
}
},
"51aaa109480d4ae6bd419aea689d22ee": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"51b8f4ce40f94ac39cf44d98f1522ec7": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_60ceb890b5644493a8886d91b9dac461",
"placeholder": "",
"style": "IPY_MODEL_40138ff29073407abb95f793509fc320",
"value": " 446/446 [00:00<00:00, 50.5kB/s]"
}
},
"55ac5c2a82ee48fe988e1e4f26c168b0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"5657a84bf4b74710b2de1a54f9236e39": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"596c2a62a635469eb74233ce00586a6f": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"59e46bbe96df4b88ad31c09096ce0e0a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5a59fb5f7acf4213847c985e66c9ee3c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_81a728910a2341a785a6f252bbb371f7",
"placeholder": "",
"style": "IPY_MODEL_69a8d50f11244ba688c183d14d2395ec",
"value": "generation_config.json: 100%"
}
},
"5b7af68130f04a63ad3efa3d9f602ebe": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_80fa3aef5e2040d9904c6b87b7214ca0",
"placeholder": "",
"style": "IPY_MODEL_0f99489932aa409b94ba34764aff19b0",
"value": " 4/4 [01:42<00:00, 42.23s/it]"
}
},
"5e5b77dd649547f896ab306fccc94a4e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_59e46bbe96df4b88ad31c09096ce0e0a",
"placeholder": "",
"style": "IPY_MODEL_8f5c7b88a2cc4b5abb0814c814833349",
"value": " 15.1k/? [00:00<00:00, 1.37MB/s]"
}
},
"5e88d6515f16475fb72d7c153422b591": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_923653dfe90e475a9efa44baf98ba9a0",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_62600092f8cc43f493b86b0169f67be1",
"value": 1
}
},
"5ebe7b4e4ed24c53b783ee46377c682d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_51aaa109480d4ae6bd419aea689d22ee",
"max": 3998751275,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_acf4e50a248342f68d26daef21baa419",
"value": 3998751275
}
},
"5f0b2a0e1953406b88af2c884904e2da": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"60ceb890b5644493a8886d91b9dac461": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"614c5332c7d045109102a329e7f69dfd": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"61e491b843c347b6b2a9948de7caf01d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"62600092f8cc43f493b86b0169f67be1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"65c62d2198e64ee4a9e6547c2733135a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"68ea891644ca4753a8e1bf278ff47e84": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"69a8d50f11244ba688c183d14d2395ec": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"6a47e60b10a6481b94aee021c8dbc7ba": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"6ab4e5676ad84807a126fffa99f7a0d4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e61ef80398444c13bf7cd20ef21a5057",
"IPY_MODEL_5ebe7b4e4ed24c53b783ee46377c682d",
"IPY_MODEL_e0fdef0087bc4a91a11932a2d933c001"
],
"layout": "IPY_MODEL_596c2a62a635469eb74233ce00586a6f"
}
},
"6c2ccfe3363b40b58fc26ea164d4ead4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4da21f53bf7f4e2d8132eb43e6ecc739",
"placeholder": "",
"style": "IPY_MODEL_735f70fac43449e3974de1b783d56d33",
"value": "model-00002-of-00004.safetensors: 100%"
}
},
"735f70fac43449e3974de1b783d56d33": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"749e8407a901483c8b513a2fb71596c8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ef01b874478b4bb497d31d2f8dd6145a",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_d50ea8cded9848ffa18be1ae6a2559df",
"value": 1
}
},
"751a46fbb8e24efabfb381a85c90fbe8": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"77204d81ff8f4ee585361a503fa647dc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"77d34c0f1de548b4872208a063bb5017": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"77ecad9f150c430fa85f5833d97c42df": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"7841bc90b6a74120ab3e603c76332a01": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7875163297284832a35aca84cbb105ce": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ba94310dc12a4a258205b14901ad3f94",
"placeholder": "",
"style": "IPY_MODEL_a93210a691414502ba3c2dff03ffb4ce",
"value": " 22.8k/? [00:00<00:00, 1.66MB/s]"
}
},
"792d75a7d18945e7972826ac5b2ac386": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"7baca79d720c40b5a923b9717e28c982": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ffabf89ecd9d48a5a3fc2a1c855ce080",
"placeholder": "",
"style": "IPY_MODEL_614c5332c7d045109102a329e7f69dfd",
"value": " 1.19M/? [00:00<00:00, 81.8MB/s]"
}
},
"7bd5d1beeb0e49e293d9f6b91bb6d7fb": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"7bfcd9acf29646db8b6123708d1ffe27": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fd0ac7ed3d3146ec85913f4e05c4a2f6",
"placeholder": "",
"style": "IPY_MODEL_77204d81ff8f4ee585361a503fa647dc",
"value": "chat_template.jinja: "
}
},
"7d3379cbd27a4218a9d84c5a12f3bb88": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7e9d5212fc7844f286e14b70cbf0bc7a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"80fa3aef5e2040d9904c6b87b7214ca0": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"81a728910a2341a785a6f252bbb371f7": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"84d27c45065e426badbfcfcdc8ff16b6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4d67b10ec7794170addb4e968e20f170",
"max": 27868174,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_55ac5c2a82ee48fe988e1e4f26c168b0",
"value": 27868174
}
},
"87765ca0996b403dbe29deef48d548bf": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"87a808c4d4f54f719adcd29de7206e1b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8c7c6bb04a3f4a1494b34529f95a195c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"8db5e86577744ff1a39c8e198eee5dd3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_4b9b3fe8dc764eedb9e18f166fe2f548",
"IPY_MODEL_cca95e973bc445d3811335debf7c446e",
"IPY_MODEL_e507a46b4c754d9a8aede2aac0d203bc"
],
"layout": "IPY_MODEL_751a46fbb8e24efabfb381a85c90fbe8"
}
},
"8f1e6c36b84c4115a671dcb9ade41c8b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8f5c7b88a2cc4b5abb0814c814833349": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"923653dfe90e475a9efa44baf98ba9a0": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "20px"
}
},
"9a079a30b4ae4bbc80122faf83e0ad59": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9beac0680e3049dfafcb6ec185fd2265": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"9c4d6839934b4b13952a850d2084d498": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9cc51d8029eb4217bc37daa918649692": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a219f3b89a34443abe612846676f9356",
"placeholder": "",
"style": "IPY_MODEL_152d7bf2a74f400db3d3ecaa719ef8d1",
"value": "Loading checkpoint shards: 100%"
}
},
"a0d0fedc5bec4f5b943fddf9a954fbdf": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e1e77d98b01f4376a6c075975c27571e",
"placeholder": "",
"style": "IPY_MODEL_6a47e60b10a6481b94aee021c8dbc7ba",
"value": "special_tokens_map.json: 100%"
}
},
"a219f3b89a34443abe612846676f9356": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a843fa23e6c94fb486bff8764574fdc5": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a93210a691414502ba3c2dff03ffb4ce": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"abe2b0a2913d4633943f44333ae799f8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_2c40c6b846924200b29616a590af1672",
"IPY_MODEL_749e8407a901483c8b513a2fb71596c8",
"IPY_MODEL_7baca79d720c40b5a923b9717e28c982"
],
"layout": "IPY_MODEL_68ea891644ca4753a8e1bf278ff47e84"
}
},
"acda8e7582934fecbbf854e66e23f698": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"acf4e50a248342f68d26daef21baa419": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"ad75f887a140416abfca615b2fc3c385": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ae6d42fb84fc4984af1d4430acdcd3c9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_350f29f737534bfba4258bc31ec274a2",
"max": 165,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_9beac0680e3049dfafcb6ec185fd2265",
"value": 165
}
},
"b07acf871a0a46f1889bfb439d13752b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"b66c6ded549d4db8a2e5ea8e5016615c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_77d34c0f1de548b4872208a063bb5017",
"placeholder": "",
"style": "IPY_MODEL_bf96e8666c224c26b0a01451d08e907a",
"value": "model-00004-of-00004.safetensors: 100%"
}
},
"ba94310dc12a4a258205b14901ad3f94": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bcda4c9a48e943a6a0ef812fcd64a6db": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bf96e8666c224c26b0a01451d08e907a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c6a1decbc0e7421db622033214913cb9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"cab602573c6940919f93e59fe6f4838d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5657a84bf4b74710b2de1a54f9236e39",
"max": 446,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_7bd5d1beeb0e49e293d9f6b91bb6d7fb",
"value": 446
}
},
"caf742160db041a1b6c2cfdf78f2dc9a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_34a9e38b0b454a69a067d1ddadec7626",
"IPY_MODEL_263b7dc0b3fd465fac89b9266b19d526",
"IPY_MODEL_5b7af68130f04a63ad3efa3d9f602ebe"
],
"layout": "IPY_MODEL_2a6aa92676c74509b58373ca604c5b3b"
}
},
"cca95e973bc445d3811335debf7c446e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2fa84865e9f14c1491402ef81517b4bd",
"max": 3372033380,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_245590db7d374515a428ff4abbd25588",
"value": 3372033380
}
},
"cef064f1c55f41bf957fc4623260fdb4": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d42d8228ea1247a1a81bb99b18c4640c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d50ea8cded9848ffa18be1ae6a2559df": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"d864d29d02c54ecfaedd7b866a6df8c2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_dee07d33b8de4c3b847fcff670e68102",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b07acf871a0a46f1889bfb439d13752b",
"value": 1
}
},
"d9020a2a2c8440db81d2cfdf0289b667": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d93be4994f104b6e99d89a9e73cd6abd": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d98c2b1e979b4929891a8ee0c11f55df": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"da4324e287e64e5ba98fc110693066df": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"dbf5ed93dac646ed979fa7a8c569dfe3": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"dbfeea8ee2374b8c8fa70431c35f281f": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d9020a2a2c8440db81d2cfdf0289b667",
"placeholder": "",
"style": "IPY_MODEL_04d39c4dda9f4a1bb01b8d6320032372",
"value": "tokenizer.json: 100%"
}
},
"dee02a37a6f44f168546ee0077dc20d1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"dee07d33b8de4c3b847fcff670e68102": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "20px"
}
},
"e0fdef0087bc4a91a11932a2d933c001": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7d3379cbd27a4218a9d84c5a12f3bb88",
"placeholder": "",
"style": "IPY_MODEL_7841bc90b6a74120ab3e603c76332a01",
"value": " 4.00G/4.00G [01:41<00:00, 60.6MB/s]"
}
},
"e1e77d98b01f4376a6c075975c27571e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e2973e6c02834a7c9f2f6ce5755f35f0": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e507a46b4c754d9a8aede2aac0d203bc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e2973e6c02834a7c9f2f6ce5755f35f0",
"placeholder": "",
"style": "IPY_MODEL_48741bbdeccb459aa4eea9c61339764b",
"value": " 3.37G/3.37G [01:40<00:00, 32.0MB/s]"
}
},
"e61ef80398444c13bf7cd20ef21a5057": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_da4324e287e64e5ba98fc110693066df",
"placeholder": "",
"style": "IPY_MODEL_8c7c6bb04a3f4a1494b34529f95a195c",
"value": "model-00001-of-00004.safetensors: 100%"
}
},
"ee23056662ad4b719b65005d776e0e72": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ef01b874478b4bb497d31d2f8dd6145a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "20px"
}
},
"f8dacdab001d4db0b6b3776ac7d3634a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_5a59fb5f7acf4213847c985e66c9ee3c",
"IPY_MODEL_ae6d42fb84fc4984af1d4430acdcd3c9",
"IPY_MODEL_02d120e49f2c4f95a6090b1d8d521767"
],
"layout": "IPY_MODEL_8f1e6c36b84c4115a671dcb9ade41c8b"
}
},
"fa9ea0d3234e41689c827485d0360885": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_9a079a30b4ae4bbc80122faf83e0ad59",
"placeholder": "",
"style": "IPY_MODEL_acda8e7582934fecbbf854e66e23f698",
"value": " 27.9M/27.9M [00:00<00:00, 44.5MB/s]"
}
},
"fd0ac7ed3d3146ec85913f4e05c4a2f6": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"fd2fe9ef6da64f72ab29d481d1739f5e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_dbfeea8ee2374b8c8fa70431c35f281f",
"IPY_MODEL_84d27c45065e426badbfcfcdc8ff16b6",
"IPY_MODEL_fa9ea0d3234e41689c827485d0360885"
],
"layout": "IPY_MODEL_4cb119127b404f46a53012c62d004e28"
}
},
"ffabf89ecd9d48a5a3fc2a1c855ce080": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"state" : {}
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/streamlit/streamlit_chat.py
================================================
import json
import requests
import streamlit as st
# Default JSON Schema shown in the "Function parameters" sidebar field: an
# object with a single required string argument, "location". It pairs with the
# default function name/description ("get_weather") used in the sidebar.
# .strip() removes the leading/trailing newlines of the triple-quoted literal.
DEFAULT_FUNCTION_PROPERTIES = """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
},
"required": ["location"]
}
""".strip()
# Session state for chat
if "messages" not in st.session_state:
    st.session_state.messages = []
st.title("💬 Chatbot")

# Model choices: "small" is served on :8081, "large" on :8000 (see URL below).
options = ["large", "small"]
if "model" not in st.session_state:
    # Seed the initial model from ?model=..., falling back to "small" when the
    # parameter is missing or not a known option (an unknown value would be an
    # invalid default for the segmented control below).
    requested_model = st.query_params.get("model")
    st.session_state.model = requested_model if requested_model in options else "small"
selection = st.sidebar.segmented_control(
    "Model", options, selection_mode="single", default=st.session_state.model
)
# segmented_control returns None when the user clears the selection; keep the
# current model rather than writing None into the URL and silently switching
# to the "large" endpoint.
if selection is None:
    selection = st.session_state.model
st.query_params.update({"model": selection})

instructions = st.sidebar.text_area(
    "Instructions",
    value="You are a helpful assistant that can answer questions and help with tasks.",
)
effort = st.sidebar.radio(
    "Reasoning effort",
    ["low", "medium", "high"],
    index=1,
)
st.sidebar.divider()
st.sidebar.subheader("Functions")
use_functions = st.sidebar.toggle("Use functions", value=False)
st.sidebar.subheader("Built-in Tools")
# Built-in Tools section
use_browser_search = st.sidebar.toggle("Use browser search", value=False)
use_code_interpreter = st.sidebar.toggle("Use code interpreter", value=False)
# Function definition fields are only shown (and only sent) when enabled.
if use_functions:
    function_name = st.sidebar.text_input("Function name", value="get_weather")
    function_description = st.sidebar.text_area(
        "Function description", value="Get the weather for a given city"
    )
    function_parameters = st.sidebar.text_area(
        "Function parameters", value=DEFAULT_FUNCTION_PROPERTIES
    )
else:
    function_name = None
    function_description = None
    function_parameters = None
st.sidebar.divider()
temperature = st.sidebar.slider(
    "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
)
max_output_tokens = st.sidebar.slider(
    "Max output tokens", min_value=1, max_value=131072, value=30000, step=1000
)
st.sidebar.divider()
debug_mode = st.sidebar.toggle("Debug mode", value=False)
if debug_mode:
    # Dump the raw conversation items being sent to the server.
    st.sidebar.divider()
    st.sidebar.code(json.dumps(st.session_state.messages, indent=2), "json")
render_input = True
# Endpoint for the selected model size.
URL = (
    "http://localhost:8081/v1/responses"
    if selection == options[1]
    else "http://localhost:8000/v1/responses"
)
def trigger_fake_tool(container):
    """Answer the pending function call with user-entered output and resume.

    The text from the "function_output" form field (default "It's sunny!") is
    appended as a ``function_call_output`` item carrying the ``call_id`` of the
    last message, then ``run`` streams the next model turn into ``container``.
    Nothing happens when the last message is not a ``function_call``.
    """
    tool_result = st.session_state.get("function_output", "It's sunny!")
    pending_call = st.session_state.messages[-1]
    if pending_call.get("type") != "function_call":
        return
    reply_item = {
        "type": "function_call_output",
        "call_id": pending_call.get("call_id"),
        "output": tool_result,
    }
    st.session_state.messages.append(reply_item)
    run(container)
def run(container):
    """Stream one model turn from the Responses endpoint into ``container``.

    Builds the ``tools`` list from the sidebar toggles, POSTs the conversation
    in ``st.session_state.messages`` to ``URL`` with ``stream=True`` and renders
    the server-sent events as they arrive:

    * reasoning / output text deltas,
    * web search actions,
    * code-interpreter status, code and outputs,
    * finished function calls.

    On ``response.completed`` the returned output items are appended to
    ``st.session_state.messages``. If the last one is a ``function_call``, a
    form is shown so the user can supply its result (see ``trigger_fake_tool``).

    Invalid JSON in the "Function parameters" field and non-2xx HTTP responses
    are reported with ``container.error`` instead of raising / failing silently.
    """
    tools = []
    if use_functions:
        try:
            parameters = json.loads(function_parameters)
        except json.JSONDecodeError as exc:
            container.error(f"Invalid JSON in function parameters: {exc}")
            return
        tools.append(
            {
                "type": "function",
                "name": function_name,
                "description": function_description,
                "parameters": parameters,
            }
        )
    # Add browser_search tool if checkbox is checked
    if use_browser_search:
        tools.append({"type": "browser_search"})
    if use_code_interpreter:
        tools.append({"type": "code_interpreter"})

    payload = {
        "input": st.session_state.messages,
        "stream": True,
        "instructions": instructions,
        "reasoning": {"effort": effort},
        "metadata": {"__debug": debug_mode},
        "tools": tools,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
    }

    # Using the response as a context manager releases the streaming
    # connection even if rendering raises part-way through.
    with requests.post(URL, json=payload, stream=True) as http_response:
        if not http_response.ok:
            container.error(
                f"Request to {URL} failed with HTTP {http_response.status_code}: "
                f"{http_response.text}"
            )
            return

        text_delta = ""  # accumulated text of the item currently streaming
        code_interpreter_sessions: dict[str, dict] = {}  # item id -> its UI slots
        # Slot that text/status updates for the current item are written to.
        # None until an output item creates one, so early or stray events
        # neither hit an unbound name nor overwrite a previous item's output.
        placeholder = None

        for line in http_response.iter_lines(decode_unicode=True):
            # Only SSE "data:" lines carry events.
            if not line or not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if not data_str:
                continue
            try:
                data = json.loads(data_str)
            except ValueError:
                continue
            if not isinstance(data, dict):
                continue

            event_type = data.get("type", "")

            if event_type == "response.output_item.added":
                item = data.get("item") or {}
                output_type = item.get("type", "message")
                if output_type == "message":
                    placeholder = container.chat_message("assistant").empty()
                elif output_type == "reasoning":
                    placeholder = container.chat_message("reasoning", avatar="🤔").empty()
                elif output_type == "web_search_call":
                    output = container.chat_message("web_search_call", avatar="🌐")
                    output.code(
                        json.dumps(item.get("action", {}), indent=4),
                        language="json",
                    )
                    placeholder = output.empty()
                elif output_type == "code_interpreter_call":
                    message_container = container.chat_message(
                        "code_interpreter_call", avatar="🧪"
                    )
                    status_placeholder = message_container.empty()
                    code_placeholder = message_container.empty()
                    outputs_container = message_container.container()
                    code_text = item.get("code") or ""
                    if code_text:
                        code_placeholder.code(code_text, language="python")
                    code_interpreter_sessions[item.get("id")] = {
                        "status": status_placeholder,
                        "code": code_placeholder,
                        "outputs": outputs_container,
                        "code_text": code_text,
                        "rendered_outputs": False,
                    }
                    placeholder = status_placeholder
                else:
                    # e.g. function_call: rendered on output_item.done instead.
                    placeholder = None
                text_delta = ""

            elif event_type == "response.reasoning_text.delta":
                if placeholder is None:
                    placeholder = container.chat_message("reasoning", avatar="🤔").empty()
                text_delta += data.get("delta", "")
                placeholder.markdown(text_delta)

            elif event_type == "response.output_text.delta":
                if placeholder is None:
                    placeholder = container.chat_message("assistant").empty()
                text_delta += data.get("delta", "")
                placeholder.markdown(text_delta)

            elif event_type == "response.output_item.done":
                item = data.get("item") or {}
                item_type = item.get("type")
                if item_type == "function_call":
                    with container.chat_message("function_call", avatar="🔨"):
                        st.markdown(f"Called `{item.get('name')}`")
                        st.caption("Arguments")
                        st.code(item.get("arguments", ""), language="json")
                elif item_type == "web_search_call":
                    if placeholder is not None:
                        placeholder.markdown("✅ Done")
                elif item_type == "code_interpreter_call":
                    session = code_interpreter_sessions.get(item.get("id"))
                    if session:
                        session["status"].markdown("✅ Done")
                        final_code = item.get("code") or session["code_text"]
                        if final_code:
                            session["code"].code(final_code, language="python")
                            session["code_text"] = final_code
                        # Outputs are rendered once per code-interpreter call.
                        if not session["rendered_outputs"]:
                            outputs = item.get("outputs") or []
                            with session["outputs"]:
                                if outputs:
                                    st.markdown("**Outputs**")
                                    for output_item in outputs:
                                        output_type = output_item.get("type")
                                        if output_type == "logs":
                                            st.code(
                                                output_item.get("logs", ""),
                                                language="text",
                                            )
                                        elif output_type == "image":
                                            st.image(
                                                output_item.get("url", ""),
                                                caption="Code interpreter image",
                                            )
                                else:
                                    st.caption("(No outputs)")
                            session["rendered_outputs"] = True
                    elif placeholder is not None:
                        placeholder.markdown("✅ Done")

            elif event_type == "response.code_interpreter_call.in_progress":
                session = code_interpreter_sessions.get(data.get("item_id"))
                if session:
                    session["status"].markdown("⏳ Running")
                elif placeholder is not None:
                    placeholder.markdown("⏳ Running")

            elif event_type == "response.code_interpreter_call.interpreting":
                session = code_interpreter_sessions.get(data.get("item_id"))
                if session:
                    session["status"].markdown("🧮 Interpreting")

            elif event_type == "response.code_interpreter_call.completed":
                session = code_interpreter_sessions.get(data.get("item_id"))
                if session:
                    session["status"].markdown("✅ Done")
                elif placeholder is not None:
                    placeholder.markdown("✅ Done")

            elif event_type == "response.code_interpreter_call_code.delta":
                session = code_interpreter_sessions.get(data.get("item_id"))
                if session:
                    session["code_text"] += data.get("delta", "")
                    if session["code_text"].strip():
                        session["code"].code(session["code_text"], language="python")

            elif event_type == "response.code_interpreter_call_code.done":
                session = code_interpreter_sessions.get(data.get("item_id"))
                if session:
                    final_code = data.get("code") or session["code_text"]
                    session["code_text"] = final_code
                    if final_code:
                        session["code"].code(final_code, language="python")

            elif event_type == "response.completed":
                completed = data.get("response") or {}
                if debug_mode:
                    container.expander("Debug", expanded=False).code(
                        (completed.get("metadata") or {}).get("__debug", ""),
                        language="text",
                    )
                st.session_state.messages.extend(completed.get("output", []))
                # A trailing function call needs a result before the model can
                # continue; collect it from the user.
                if (
                    st.session_state.messages
                    and st.session_state.messages[-1].get("type") == "function_call"
                ):
                    with container.form("function_output_form"):
                        st.text_input(
                            "Enter function output",
                            value=st.session_state.get("function_output", "It's sunny!"),
                            key="function_output",
                        )
                        st.form_submit_button(
                            "Submit function output",
                            on_click=trigger_fake_tool,
                            args=[container],
                        )
            # Optionally handle other event types...
# Chat display: re-render the whole stored conversation on every rerun.
for msg in st.session_state.messages:
    msg_type = msg.get("type")
    if msg_type == "message":
        with st.chat_message(msg["role"]):
            # `content` may be absent or null on some items; treat it as empty.
            for item in msg.get("content") or []:
                if item.get("type") in ("text", "output_text", "input_text"):
                    st.markdown(item["text"])
                    if item.get("annotations"):
                        annotation_lines = "\n".join(
                            f"- {annotation.get('url')}"
                            for annotation in item["annotations"]
                            if annotation.get("url")
                        )
                        # Skip the caption when no annotation carries a URL.
                        if annotation_lines:
                            st.caption(f"**Annotations:**\n{annotation_lines}")
    elif msg_type == "reasoning":
        with st.chat_message("reasoning", avatar="🤔"):
            for item in msg.get("content") or []:
                if item.get("type") == "reasoning_text":
                    st.markdown(item["text"])
    elif msg_type == "function_call":
        with st.chat_message("function_call", avatar="🔨"):
            st.markdown(f"Called `{msg.get('name')}`")
            st.caption("Arguments")
            st.code(msg.get("arguments", ""), language="json")
    elif msg_type == "function_call_output":
        with st.chat_message("function_call_output", avatar="✅"):
            st.caption("Output")
            st.code(msg.get("output", ""), language="text")
    elif msg_type == "web_search_call":
        with st.chat_message("web_search_call", avatar="🌐"):
            st.code(json.dumps(msg.get("action", {}), indent=4), language="json")
            st.markdown("✅ Done")
    elif msg_type == "code_interpreter_call":
        with st.chat_message("code_interpreter_call", avatar="🧪"):
            st.markdown("✅ Done")
            # Re-show the executed code and its outputs, matching the live
            # streaming view, instead of only a bare status line.
            if msg.get("code"):
                st.code(msg["code"], language="python")
            for output_item in msg.get("outputs") or []:
                if output_item.get("type") == "logs":
                    st.code(output_item.get("logs", ""), language="text")
                elif output_item.get("type") == "image":
                    st.image(output_item.get("url", ""), caption="Code interpreter image")
# Input field: record the user's turn, echo it, then stream the model's reply
# into a fresh container.
if render_input and (prompt := st.chat_input("Type a message...")):
    user_turn = {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": prompt}],
    }
    st.session_state.messages.append(user_turn)
    with st.chat_message("user"):
        st.markdown(prompt)
    run(st.container())
================================================
FILE: gpt-oss-mcp-server/README.md
================================================
# MCP Servers for gpt-oss reference tools
This directory contains MCP servers for the reference tools in the [gpt-oss](https://github.com/openai/gpt-oss) repository.
You can set up these tools behind MCP servers and use them in your applications.
For inference services that integrate with MCP, you can also use these as reference tools.
In particular, this directory contains a `build-system-prompt.py` script that will generate exactly the same system prompt as `reference-system-prompt.py`.
The `build-system-prompt.py` script showcases all the care needed to automatically discover the tools and construct the system prompt before feeding it into Harmony.
## Usage
```bash
# Install the dependencies
uv pip install -r requirements.txt
```
```bash
# Assume we have harmony and gpt-oss installed
uv pip install "mcp[cli]"
# start the servers
mcp run -t sse browser_server.py:mcp
mcp run -t sse python_server.py:mcp
```
You can now use MCP inspector to play with the tools.
Once opened, set SSE to `http://localhost:8001/sse` (browser server) and `http://localhost:8000/sse` (python server) respectively.
To compare the system prompt and see how to construct it via MCP service discovery, see `build-system-prompt.py`.
This script will generate exactly the same system prompt as `reference-system-prompt.py`.
================================================
FILE: gpt-oss-mcp-server/browser_server.py
================================================
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Union, Optional
from mcp.server.fastmcp import Context, FastMCP
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend
@dataclass
class AppContext:
    """Server-lifetime state: one SimpleBrowserTool per client session id."""

    browsers: dict[str, SimpleBrowserTool] = field(default_factory=dict)

    def create_or_get_browser(self, session_id: str) -> SimpleBrowserTool:
        """Return the browser for ``session_id``, creating it on first use.

        The search backend is chosen from the ``BROWSER_BACKEND`` environment
        variable: "exa" (default) or "youcom"; anything else raises ValueError.
        """
        if session_id in self.browsers:
            return self.browsers[session_id]
        backend_name = os.getenv("BROWSER_BACKEND", "exa")
        backend_classes = {"youcom": YouComBackend, "exa": ExaBackend}
        if backend_name not in backend_classes:
            raise ValueError(f"Invalid tool backend: {backend_name}")
        new_browser = SimpleBrowserTool(backend=backend_classes[backend_name](source="web"))
        self.browsers[session_id] = new_browser
        return new_browser

    def remove_browser(self, session_id: str) -> None:
        """Forget the browser for ``session_id``; unknown ids are ignored."""
        if session_id in self.browsers:
            del self.browsers[session_id]
@asynccontextmanager
async def app_lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Provide a fresh, empty AppContext for as long as the server runs."""
    shared_state = AppContext()
    yield shared_state
# Pass lifespan to server
# MCP server named "browser", listening on port 8001. app_lifespan provides the
# AppContext that holds one browser per client. build-system-prompt.py reads
# serverInfo.name and `instructions` from the initialize response and uses them
# as the tool-namespace name and description in the Harmony system prompt.
mcp = FastMCP(
name="browser",
instructions=r"""
Tool for browsing.
The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
Cite information from the tool using the following format:
`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
Do not quote more than 10 words directly from the tool output.
sources=web
""".strip(),
lifespan=app_lifespan,
port=8001,
)
@mcp.tool(
    name="search",
    title="Search for information",
    description=
    "Searches for information related to `query` and displays `topn` results.",
)
async def search(ctx: Context,
                 query: str,
                 topn: int = 10,
                 source: Optional[str] = None) -> str:
    """Run a search in this client's browser and return the text replies, newline-joined."""
    app_state = ctx.request_context.lifespan_context
    browser = app_state.create_or_get_browser(ctx.client_id)
    texts = [
        reply.content[0].text
        async for reply in browser.search(query=query, topn=topn, source=source)
        if reply.content and hasattr(reply.content[0], 'text')
    ]
    return "\n".join(texts)
@mcp.tool(
    name="open",
    title="Open a link or page",
    description="""
Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
Valid link ids are displayed with the formatting: `【{id}†.*】`.
If `cursor` is not provided, the most recent page is implied.
If `id` is a string, it is treated as a fully qualified URL associated with `source`.
If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
Use this function without `id` to scroll to a new location of an opened page.
""".strip(),
)
async def open_link(ctx: Context,
                    id: Union[int, str] = -1,
                    cursor: int = -1,
                    loc: int = -1,
                    num_lines: int = -1,
                    view_source: bool = False,
                    source: Optional[str] = None) -> str:
    """Open a link / scroll a page in this client's browser; return the text replies, newline-joined."""
    app_state = ctx.request_context.lifespan_context
    browser = app_state.create_or_get_browser(ctx.client_id)
    replies = browser.open(id=id,
                           cursor=cursor,
                           loc=loc,
                           num_lines=num_lines,
                           view_source=view_source,
                           source=source)
    texts = [
        reply.content[0].text
        async for reply in replies
        if reply.content and hasattr(reply.content[0], 'text')
    ]
    return "\n".join(texts)
@mcp.tool(
    name="find",
    title="Find pattern in page",
    description=
    "Finds exact matches of `pattern` in the current page, or the page given by `cursor`.",
)
async def find_pattern(ctx: Context, pattern: str, cursor: int = -1) -> str:
    """Search the current (or `cursor`) page for `pattern`; return the text replies, newline-joined."""
    app_state = ctx.request_context.lifespan_context
    browser = app_state.create_or_get_browser(ctx.client_id)
    texts = [
        reply.content[0].text
        async for reply in browser.find(pattern=pattern, cursor=cursor)
        if reply.content and hasattr(reply.content[0], 'text')
    ]
    return "\n".join(texts)
================================================
FILE: gpt-oss-mcp-server/build-system-prompt.py
================================================
import datetime
import asyncio
from gpt_oss.tokenizer import get_tokenizer
from openai_harmony import (
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
SystemContent,
ToolNamespaceConfig,
ToolDescription,
load_harmony_encoding,
)
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import ListToolsResult
async def list_server_and_tools(server_url: str):
    """Connect to the MCP server at ``server_url`` over SSE.

    Returns ``(initialize_response, list_tools_response)``.
    """
    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            server_hello = await session.initialize()
            tool_listing = await session.list_tools()
            return server_hello, tool_listing
def trim_schema(schema: dict) -> dict:
    """Turn an MCP-generated JSON Schema into Harmony's variant (in place).

    Mutates and returns ``schema``:

    * removes ``title`` and ``"default": None``;
    * rewrites ``"anyOf": [{"type": "a"}, {"type": "b"}]`` into
      ``"type": ["a", "b"]``, dropping ``"null"`` because Harmony ignores it.
      This only happens when every member has a ``type``; members such as
      ``{"$ref": ...}`` leave ``anyOf`` untouched instead of raising KeyError;
    * recurses into ``properties`` values and array ``items``.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]
    any_of = schema.get("anyOf")
    if isinstance(any_of, list) and all(
        isinstance(member, dict) and "type" in member for member in any_of
    ):
        schema["type"] = [
            member["type"] for member in any_of if member["type"] != "null"
        ]
        del schema["anyOf"]
    properties = schema.get("properties")
    if isinstance(properties, dict):
        schema["properties"] = {
            name: trim_schema(sub_schema) for name, sub_schema in properties.items()
        }
    items = schema.get("items")
    if isinstance(items, dict):
        schema["items"] = trim_schema(items)
    return schema
def post_process_tools_description(
        list_tools_result: ListToolsResult) -> ListToolsResult:
    """Adapt an MCP ``list_tools`` result for Harmony, in place.

    Every tool's ``inputSchema`` is normalised with ``trim_schema``. Tools whose
    annotations set ``include_in_prompt`` to a falsy value are dropped, e.g. the
    plain text-in/text-out python tool, whose schema does not belong in the prompt.
    """
    prompt_tools = []
    for tool in list_tools_result.tools:
        tool.inputSchema = trim_schema(tool.inputSchema)
        if getattr(tool.annotations, "include_in_prompt", True):
            prompt_tools.append(tool)
    list_tools_result.tools = prompt_tools
    return list_tools_result
tokenizer = get_tokenizer()
# MCP SSE endpoints to discover tools from, in prompt order.
tools_urls = [
    "http://localhost:8001/sse",  # browser
    "http://localhost:8000/sse",  # python
]

# One Harmony tool namespace per MCP server, built from its initialize
# response (name, instructions) and its post-processed tool list.
harmony_tool_descriptions = []
for server_url in tools_urls:
    server_hello, tool_listing = asyncio.run(list_server_and_tools(server_url))
    tool_listing = post_process_tools_description(tool_listing)
    namespace_tools = [
        ToolDescription.new(name=tool.name,
                            description=tool.description,
                            parameters=tool.inputSchema)
        for tool in tool_listing.tools
    ]
    harmony_tool_descriptions.append(
        ToolNamespaceConfig(
            name=server_hello.serverInfo.name,
            description=server_hello.instructions,
            tools=namespace_tools,
        ))

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
system_message_content = (
    SystemContent.new()
    .with_reasoning_effort(ReasoningEffort.LOW)
    .with_conversation_start_date(start_date)
)
for tool_description in harmony_tool_descriptions:
    system_message_content = system_message_content.with_tools(tool_description)

# Render system + empty developer message and print the decoded prompt text.
conversation = Conversation.from_messages([
    Message.from_role_and_content(Role.SYSTEM, system_message_content),
    Message.from_role_and_content(Role.DEVELOPER,
                                  DeveloperContent.new().with_instructions("")),
])
print(tokenizer.decode(encoding.render_conversation(conversation)))
================================================
FILE: gpt-oss-mcp-server/pyproject.toml
================================================
[project]
name = "gpt-oss-mcp-server"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"mcp[cli]>=1.12.2",
# "gpt_oss"
]
================================================
FILE: gpt-oss-mcp-server/python_server.py
================================================
from mcp.server.fastmcp import FastMCP
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from openai_harmony import Message, TextContent, Author, Role
# MCP server named "python". No lifespan or port is passed here; the README and
# build-system-prompt.py expect it at http://localhost:8000/sse.
# build-system-prompt.py uses `instructions` from the initialize response as
# this namespace's description in the Harmony system prompt.
mcp = FastMCP(
name="python",
instructions=r"""
Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.
""".strip(),
)
@mcp.tool(
    name="python",
    title="Execute Python code",
    description="""
Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.
""",
    annotations={
        # Keep this tool's schema out of the Harmony prompt: it is plain
        # text in, text out (build-system-prompt.py filters on this flag).
        "include_in_prompt": False,
    })
async def python(code: str) -> str:
    """Execute ``code`` with PythonTool and return its reply text, newline-joined.

    Replies with no content, or whose first content part has no ``text``, are
    skipped rather than raising IndexError/AttributeError. This matches how the
    browser tool wrappers collect their replies.
    """
    tool = PythonTool()
    request = Message(author=Author(role=Role.TOOL, name="python"),
                      content=[TextContent(text=code)])
    texts = []
    async for message in tool.process(request):
        if message.content and hasattr(message.content[0], "text"):
            texts.append(message.content[0].text)
    return "\n".join(texts)
================================================
FILE: gpt-oss-mcp-server/reference-system-prompt.py
================================================
import datetime
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from gpt_oss.tokenizer import tokenizer
from openai_harmony import (
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
SystemContent,
load_harmony_encoding,
)
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
system_message_content = (
    SystemContent.new()
    .with_reasoning_effort(ReasoningEffort.LOW)
    .with_conversation_start_date(start_date)
)

# Register the reference tools directly (browser first, then python).
browser_tool = SimpleBrowserTool(backend=YouComBackend(source="web"))
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
python_tool = PythonTool()
system_message_content = system_message_content.with_tools(python_tool.tool_config)

# Render system + empty developer message and print the decoded prompt text.
conversation = Conversation.from_messages([
    Message.from_role_and_content(Role.SYSTEM, system_message_content),
    Message.from_role_and_content(Role.DEVELOPER,
                                  DeveloperContent.new().with_instructions("")),
])
print(tokenizer.decode(encoding.render_conversation(conversation)))
================================================
FILE: gpt_oss/__init__.py
================================================
================================================
FILE: gpt_oss/chat.py
================================================
"""
Harmony chat with tools
"""
import atexit
import argparse
import asyncio
import datetime
import os
from pathlib import Path
try:
import gnureadline as readline
except ImportError:
import readline
import torch
import termcolor
from gpt_oss.tools import apply_patch
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
StreamState,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
# Maps the --reasoning-effort CLI choice to Harmony's ReasoningEffort value.
REASONING_EFFORT = {
    effort_name: getattr(ReasoningEffort, effort_name.upper())
    for effort_name in ("high", "medium", "low")
}
def get_user_input():
    """Read one line of user input on rank 0 and share it with every rank.

    When torch.distributed is initialised, only rank 0 calls ``input()``; the
    line is then broadcast from rank 0. Otherwise this is a plain ``input()``.
    """
    distributed = torch.distributed.is_initialized()
    rank = torch.distributed.get_rank() if distributed else 0
    shared = [input() if rank == 0 else ""]
    if distributed:
        torch.distributed.broadcast_object_list(shared, 0)
    return shared[0]
# Interactive Harmony chat loop.
# Loads a token generator for args.backend and builds the system message (plus
# an optional developer message). It then loops forever:
# - when the last message has no recipient, it reads a user turn;
# - otherwise it executes the addressed tool (browser.*, python*, or
#   functions.apply_patch) and appends the tool output.
# Each iteration then renders the conversation for completion, streams
# assistant tokens through a StreamableParser, and appends the parsed messages
# to `messages`.
def main(args):
# Select the inference backend. triton/torch call init_distributed() to obtain
# the device (presumably also setting up torch.distributed; see get_user_input).
match args.backend:
case "triton":
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
from gpt_oss.torch.utils import init_distributed
device = init_distributed()
generator = TritonGenerator(args.checkpoint, args.context, device)
case "torch":
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
from gpt_oss.torch.utils import init_distributed
device = init_distributed()
generator = TorchGenerator(args.checkpoint, device)
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
# NOTE(review): tensor parallelism is hard-coded to 2 here.
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
case _:
raise ValueError(f"Invalid backend: {args.backend}")
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
# System message: CLI reasoning effort and today's date.
system_message_content = (
SystemContent.new()
.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
)
# Optional built-in tools; each tool's config is added to the system message.
if args.browser:
backend = YouComBackend(
source="web",
)
browser_tool = SimpleBrowserTool(backend=backend)
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
if args.python:
python_tool = PythonTool()
system_message_content = system_message_content.with_tools(python_tool.tool_config)
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
messages = [system_message]
# Developer message. With --apply-patch it is the user's --developer-message
# (if any) followed by apply_patch.md, and it declares an `apply_patch`
# function tool. Without it, a developer message is only added when
# --developer-message is non-empty.
if args.apply_patch:
apply_patch_instructions = Path(apply_patch.__file__).parent / "apply_patch.md"
developer_message = ""
if args.developer_message:
developer_message = args.developer_message + "\n"
developer_message += apply_patch_instructions.read_text()
developer_message_content = (
DeveloperContent.new()
.with_instructions(developer_message)
.with_function_tools([
ToolDescription.new(
"apply_patch",
"Patch a file",
parameters={
"type": "string",
"description": "Formatted patch code",
"default": "*** Begin Patch\n*** End Patch\n",
}
),
])
)
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
elif args.developer_message:
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
else:
developer_message_content = None
# Raw mode prints the rendered Harmony text. The rendering of an empty user
# message is split before/after its final token so typed input can be echoed
# between the two halves.
if args.raw:
conversation = Conversation.from_messages(messages)
tokens = encoding.render_conversation(conversation)
system_message = encoding.decode(tokens)
print(system_message, flush=True, end="")
empty_user_message_tokens = encoding.render(Message.from_role_and_content(Role.USER, ""))
user_message_start = encoding.decode(empty_user_message_tokens[:-1])
user_message_end = encoding.decode(empty_user_message_tokens[-1:])
else:
# System message
print(termcolor.colored("System Message:", "cyan"), flush=True)
print(termcolor.colored("Model Identity:", "cyan"), system_message_content.model_identity, flush=True)
print(termcolor.colored("Reasoning Effort:", "cyan"), system_message_content.reasoning_effort, flush=True)
print(termcolor.colored("Conversation Start Date:", "cyan"), system_message_content.conversation_start_date, flush=True)
print(termcolor.colored("Knowledge Cutoff:", "cyan"), system_message_content.knowledge_cutoff, flush=True)
print(termcolor.colored("Browser Tool:", "cyan"), "Enabled" if args.browser else "Disabled", flush=True)
print(termcolor.colored("Python Tool:", "cyan"), "Enabled" if args.python else "Disabled", flush=True)
print(termcolor.colored("Apply Patch Function:", "cyan"), "Enabled" if args.apply_patch else "Disabled", flush=True)
if developer_message_content:
print(termcolor.colored("Developer Message:", "yellow"), flush=True)
print(developer_message_content.instructions, flush=True)
# Print the system message and the user message start
MESSAGE_PADDING = 12
while True:
last_message = messages[-1]
# No recipient: it is the user's turn. Otherwise the last assistant message
# addressed a tool/function that must run before sampling again.
if last_message.recipient is None:
if args.raw:
print(user_message_start, end="", flush=True)
user_message = get_user_input()
print(user_message_end, flush=True, end="")
else:
print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True)
user_message = get_user_input()
user_message = Message.from_role_and_content(Role.USER, user_message)
messages.append(user_message)
else:
# Tool or function call
if last_message.recipient.startswith("browser."):
assert args.browser, "Browser tool is not enabled"
# NOTE(review): labelled "Search" for every browser.* call, including open/find.
tool_name = "Search"
async def run_tool():
results = []
async for msg in browser_tool.process(last_message):
results.append(msg)
return results
result = asyncio.run(run_tool())
messages += result
elif last_message.recipient.startswith("python"):
assert args.python, "Python tool is not enabled"
tool_name = "Python"
async def run_tool():
results = []
async for msg in python_tool.process(last_message):
results.append(msg)
return results
result = asyncio.run(run_tool())
messages += result
elif last_message.recipient == "functions.apply_patch":
assert args.apply_patch, "Apply patch tool is not enabled"
tool_name = "Apply Patch"
text = last_message.content[0].text
tool_output = None
if text.startswith("{"):
# this is json, try to extract the patch from it
import json
try:
some_dict = json.loads(text)
# popitem() removes the last key/value pair; its value is used as the patch text.
_, text = some_dict.popitem()
except Exception as e:
tool_output = f"Error parsing JSON: {e}"
if tool_output is None:
try:
tool_output = apply_patch.apply_patch(text)
except Exception as e:
tool_output = f"Error applying patch: {e}"
# Tool reply authored by the called function, addressed back to the
# assistant on the same channel as the call.
message = (
Message(
author=Author.new(Role.TOOL, last_message.recipient),
content=[TextContent(text=tool_output)]
)
.with_recipient("assistant")
)
if last_message.channel:
message = message.with_channel(last_message.channel)
result = [message]
messages += result
else:
raise ValueError(f"Unknown tool or function call: {last_message.recipient}")
# Print the tool or function call result
if args.raw:
rendered_result = encoding.render_conversation(Conversation.from_messages(result))
print(encoding.decode(rendered_result), flush=True, end="")
else:
print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True)
if tool_name == "Search" and not args.show_browser_results:
print("[Search results fed to the model]")
else:
print(result[0].content[0].text)
# Sample the next assistant message(s) until a stop token for assistant actions.
conversation = Conversation.from_messages(messages)
tokens = encoding.render_conversation_for_completion(
conversation, Role.ASSISTANT
)
if args.raw:
# Print the last two tokens, which are the start of the assistant message
print(encoding.decode(tokens[-2:]), flush=True, end="")
parser = StreamableParser(encoding, role=Role.ASSISTANT)
# field_created: whether a header has been printed for the current message.
# current_output_text: text already printed this turn.
# output_text_delta_buffer: text not yet printed (held back while a citation
# may still be incomplete).
field_created = False
current_output_text = ""
output_text_delta_buffer = ""
for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
parser.process(predicted_token)
if args.raw:
print(encoding.decode([predicted_token]), end="", flush=True)
continue
if parser.state == StreamState.EXPECT_START:
print("") # new line
field_created = False
if not parser.last_content_delta:
continue
# Print a header once per message: final answer, tool call, or chain of thought.
if not field_created:
field_created = True
if parser.current_channel == "final":
print(termcolor.colored("Assistant:", "green"), flush=True)
elif parser.current_recipient is not None:
print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True)
else:
print(termcolor.colored("CoT:", "yellow"), flush=True)
should_send_output_text_delta = True
output_text_delta_buffer += parser.last_content_delta
# With the browser enabled, citations are normalised on the full text so
# far; printing is deferred while a partial citation is pending.
if args.browser:
updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer)
output_text_delta_buffer = updated_output_text[len(current_output_text):]
if has_partial_citations:
should_send_output_text_delta = False
if should_send_output_text_delta:
print(output_text_delta_buffer, end="", flush=True)
current_output_text += output_text_delta_buffer
output_text_delta_buffer = ""
messages += parser.messages
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Chat example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "checkpoint",
        metavar="FILE",
        type=str,
        help="Path to the SafeTensors checkpoint",
    )
    parser.add_argument(
        "-r",
        "--reasoning-effort",
        metavar="REASONING_EFFORT",
        type=str,
        default="low",
        choices=["high", "medium", "low"],
        help="Reasoning effort",
    )
    parser.add_argument(
        "-a",
        "--apply-patch",
        action="store_true",
        help="Make apply_patch function available to the model",
    )
    parser.add_argument(
        "-b",
        "--browser",
        default=False,
        action="store_true",
        help="Use browser tool",
    )
    parser.add_argument(
        "--show-browser-results",
        default=False,
        action="store_true",
        help="Show browser results",
    )
    parser.add_argument(
        "-p",
        "--python",
        default=False,
        action="store_true",
        help="Use python tool",
    )
    parser.add_argument(
        "--developer-message",
        default="",
        help="Developer message",
    )
    parser.add_argument(
        "-c",
        "--context",
        metavar="CONTEXT",
        type=int,
        default=8192,
        help="Max context length",
    )
    parser.add_argument(
        "--raw",
        default=False,
        action="store_true",
        help="Raw mode (does not render Harmony encoding)",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="triton",
        choices=["triton", "torch", "vllm"],
        help="Inference backend",
    )
    args = parser.parse_args()

    # Line-editing history in ~/.chat, only for single-process runs.
    if int(os.environ.get("WORLD_SIZE", 1)) == 1:
        histfile = os.path.join(os.path.expanduser("~"), ".chat")
        # Cap the saved history whether or not the file exists yet; previously
        # the cap was skipped on the first run (FileNotFoundError path).
        readline.set_history_length(10000)
        try:
            readline.read_history_file(histfile)
        except FileNotFoundError:
            pass
        atexit.register(readline.write_history_file, histfile)

    main(args)
================================================
FILE: gpt_oss/evals/README.md
================================================
# `gpt_oss.evals`
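This module is a reincarnation of [simple-evals](https://github.com/openai/simple-evals) adapted for gpt-oss. It lets you
run GPQA, HealthBench (including the hard and consensus subsets) and AIME 2025 against a runtime that supports the Responses API (or, with `--sampler chat_completions`, Chat Completions). By default it targets `http://localhost:8000/v1`; use `--base-url` to point it elsewhere.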
================================================
FILE: gpt_oss/evals/__init__.py
================================================
================================================
FILE: gpt_oss/evals/__main__.py
================================================
import argparse
import json
from datetime import datetime
from . import report
from .basic_eval import BasicEval
from .gpqa_eval import GPQAEval
from .aime_eval import AIME25Eval
from .healthbench_eval import HealthBenchEval
from .chat_completions_sampler import (
OPENAI_SYSTEM_MESSAGE_API,
ChatCompletionsSampler,
)
from .responses_sampler import ResponsesSampler
def main():
parser = argparse.ArgumentParser(
description="Evaluate the models.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--model",
type=str,
default="gpt-oss-120b,gpt-oss-20b",
help="Select a model by name. Accepts a comma-separated list.",
)
parser.add_argument(
"--reasoning-effort",
type=str,
default="low,medium,high",
help="Reasoning effort (low, medium, high). Accepts a comma-separated list.",
)
parser.add_argument(
"--sampler",
type=str,
choices=["responses", "chat_completions"],
default="responses",
help="Sampler backend to use for models.",
)
parser.add_argument(
"--base-url",
type=str,
default="http://localhost:8000/v1",
help="Base URL for the API.",
)
parser.add_argument(
"--eval",
type=str,
default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25",
help="Select an eval by name. Accepts a comma-separated list.",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature",
)
parser.add_argument(
"--n-threads",
type=int,
default=1584,
help="Number of threads to run.",
)
parser.add_argument(
"--debug", action="store_true", help="Run in debug mode"
)
parser.add_argument(
"--examples", type=int, help="Number of examples to use (overrides default)"
)
args = parser.parse_args()
sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler
models = {}
for model_name in args.model.split(","):
for reasoning_effort in args.reasoning_effort.split(","):
models[f"{model_name}-{reasoning_effort}"] = sampler_cls(
model=model_name,
reasoning_model=True,
reasoning_effort=reasoning_effort,
temperature=args.temperature,
base_url=args.base_url,
max_tokens=131_072,
)
print(f"Running with args {args}")
grading_sampler = ChatCompletionsSampler(
model="gpt-4.1-2025-04-14",
system_message=OPENAI_SYSTEM_MESSAGE_API,
max_tokens=2048,
base_url="https://api.openai.com/v1",
)
def get_evals(eval_name, debug_mode):
num_examples = (
args.examples if args.examples is not None else (5 if debug_mode else None)
)
# Set num_examples = None to reproduce full evals
match eval_name:
case "basic":
return BasicEval()
case "gpqa":
return GPQAEval(
n_repeats=1 if args.debug else 8,
num_examples=num_examples,
debug=debug_mode,
n_threads=args.n_threads or 1,
)
case "healthbench":
return HealthBenchEval(
grader_model=grading_sampler,
num_examples=10 if debug_mode else num_examples,
n_repeats=1,
n_threads=args.n_threads or 1,
subset_name=None,
)
case "healthbench_hard":
return HealthBenchEval(
grader_model=grading_sampler,
num_examples=10 if debug_mode else num_examples,
n_repeats=1,
n_threads=args.n_threads or 1,
subset_name="hard",
)
case "healthbench_consensus":
return HealthBenchEval(
grader_model=grading_sampler,
num_examples=10 if debug_mode else num_examples,
n_repeats=1,
n_threads=args.n_threads or 1,
subset_name="consensus",
)
case "aime25":
return AIME25Eval(
n_repeats=1 if args.debug else 8,
num_examples=num_examples,
n_threads=args.n_threads or 1,
)
case _:
raise Exception(f"Unrecognized eval type: {eval_name}")
evals = {}
for eval_name in args.eval.split(","):
evals[eval_name] = get_evals(eval_name, args.debug)
debug_suffix = "_DEBUG" if args.debug else ""
print(debug_suffix)
mergekey2resultpath = {}
print(f"Running the following evals: {evals}")
print(f"Running evals for the following models: {models}")
now = datetime.now()
date_str = now.strftime("%Y%m%d_%H%M%S")
for model_name, sampler in models.items():
model_name = model_name.replace("/", "__")
for eval_name, eval_obj in evals.items():
result = eval_obj(sampler)
# ^^^ how to use a sampler
file_stem = f"{eval_name}_{model_name}_temp{args.temperature}"
# file stem should also include the year, month, day, and time in hours and minutes
file_stem += f"_{date_str}"
report_filename = f"/tmp/{file_stem}{debug_suffix}.html"
print(f"Writing report to {report_filename}")
with open(report_filename, "w") as fh:
fh.write(report.make_report(result))
assert result.metrics is not None
metrics = result.metrics | {"score": result.score}
# Sort metrics by key
metrics = dict(sorted(metrics.items()))
print(metrics)
result_filename = f"/tmp/{file_stem}{debug_suffix}.json"
with open(result_filename, "w") as f:
f.write(json.dumps(metrics, indent=2))
print(f"Writing results to {result_filename}")
full_result_filename = f"/tmp/{file_stem}{debug_suffix}_allresults.json"
with open(full_result_filename, "w") as f:
result_dict = {
"score": result.score,
"metrics": result.metrics,
"htmls": result.htmls,
"convos": result.convos,
"metadata": result.metadata,
}
f.write(json.dumps(result_dict, indent=2))
print(f"Writing all results to {full_result_filename}")
mergekey2resultpath[f"{file_stem}"] = result_filename
merge_metrics = []
for eval_model_name, result_filename in mergekey2resultpath.items():
try:
result = json.load(open(result_filename, "r+"))
except Exception as e:
print(e, result_filename)
continue
result = result.get("f1_score", result.get("score", None))
eval_name = eval_model_name[: eval_model_name.find("_")]
model_name = eval_model_name[eval_model_name.find("_") + 1 :]
merge_metrics.append(
{"eval_name": eval_name, "model_name": model_name, "metric": result}
)
print(merge_metrics)
return merge_metrics
if __name__ == "__main__":
main()
================================================
FILE: gpt_oss/evals/abcd_grader.py
================================================
import re
import sys
_PATTERNS = [
# 0)"**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter.
re.compile(
r'''(?ix) # case‐insensitive, ignore‐space
(?:\*{1,2}|_{1,2}) # leading *…* or _…_
Answer[s]? # Answer or Answers
\s*[:\-–]? # optional separator
(?:\*{1,2}|_{1,2}) # closing wrapper
\s* # optional space
([ABCD])\b # the actual letter
''',
re.X
),
# 0.1)
re.compile(r'''(?ix) # ignore case, allow verbose mode
^\s* # optional leading whitespace
(?:\*{1,2}|_{1,2})? # optional markdown wrapper
Answer:? # the word 'answer' with an optional colon
(?:\*{1,2}|_{1,2})? # optional markdown wrapper again
\s*:?\s* # optional colon with optional spaces
(?:\*{1,2}|_{1,2})? # optional markdown wrapper before letter
([ABCD]) # capture the letter
(?:\*{1,2}|_{1,2})? # optional markdown wrapper after letter
\s* # optional trailing whitespace, end of line
''', re.MULTILINE),
# 1) Answer: (C) or Answers: (B)
re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'),
# 2) Answer: C or Answers – D
re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'),
# 3) Option B or Choice: C
re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'),
# 7) LaTeX \boxed{...A...}, catches both \boxed{A} and
# \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc.
re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE),
# 7.5) LaTeX \boxed{\textbf{...C...}}
re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
# 7.51) LaTeX \boxed{\text{...C...}}
re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
# 4) bare singletons: (A) [B]
re.compile(r'(?x)(? str | None:
"""
Scan text (with Markdown/LaTeX wrappers intact) and return
'A', 'B', 'C', or 'D' if a correct-answer declaration is found.
Otherwise return None.
"""
matches = []
for prio, pat in enumerate(_PATTERNS):
m = pat.search(text)
if m:
letter = m.group(1).upper()
if letter in 'ABCD':
matches.append((prio, m, letter))
matches.sort(key=lambda triple: (
triple[0],
len(triple[1].group(0))
))
for _, match, letter in matches:
return letter
return text.removeprefix('**')[:1]
def main():
    """CLI: print the extracted answer letter for each file given as an argument,
    or, with no arguments, for each line read from stdin."""
    if len(sys.argv) > 1:
        # Process files
        for fn in sys.argv[1:]:
            with open(fn, encoding='utf8') as fp:
                text = fp.read()
            ans = extract_abcd(text)
            print(f"{fn} ➜ {ans!r}")
    else:
        # Read from stdin
        for line in sys.stdin:
            ans = extract_abcd(line)
            # Drop the line's own newline for display so the arrow and the
            # answer stay on the same output line.
            shown = line.rstrip("\n")
            print(f"{shown} ➜ {ans!r}")


if __name__ == "__main__":
    main()
================================================
FILE: gpt_oss/evals/aime_eval.py
================================================
"""
AIME 2025: https://huggingface.co/datasets/opencompass/AIME2025
"""
import random
import re
import pandas
from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
# User prompt for one AIME problem; the doubled braces render as a literal "\boxed{}".
AIME_TEMPLATE = "\n{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n"


def format_aime_question(row):
    """Render the prompt for a dataset row (reads only ``row["question"]``)."""
    question_text = row["question"]
    return AIME_TEMPLATE.format(question=question_text)
def extract_boxed_text(text):
    """Return the candidate final answer found in a model response.

    Walks the ``boxed{...}`` / ``framebox{...}`` captures from last to first and
    returns the first non-empty one, keeping only the part after its last comma
    (stripped). Without such a capture, returns the last run of digits in the
    text, or "" when there is none.
    """
    boxed_hits = re.findall(r'boxed{(.*?)}|framebox{(.*?)}', text, re.DOTALL)
    for boxed_body, framebox_body in reversed(boxed_hits):
        # Only one alternative of the pattern can be non-empty per hit.
        body = boxed_body or framebox_body
        if body:
            return body.split(',')[-1].strip()
    # get the last integer if no pattern found
    digit_runs = re.findall(r'\d+', text, re.DOTALL)
    return digit_runs[-1] if digit_runs else ""
def normalize_number(s):
    """Return the digits at the very start of ``s`` (e.g. "042abc" -> "042"), or None if it does not start with a digit."""
    leading_digits = re.match(r"\d+", s)  # match digits from the start
    return leading_digits.group(0) if leading_digits else None
class AIME25Eval(Eval):
    """AIME 2025 (parts I and II), scored by exact match on the final integer answer."""

    # JSONL sources for the two AIME 2025 papers (no placeholders, so plain strings).
    _DATA_URLS = (
        "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl",
        "https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl",
    )

    def __init__(
        self,
        n_repeats: int = 4,
        num_examples: int | None = None,  # restrict to a subset of the data for debugging
        n_threads: int = 1,
    ):
        """Download both AIME 2025 files and build the list of examples.

        Args:
            n_repeats: number of copies of the example list; must be 1 when
                ``num_examples`` is set.
            num_examples: if truthy, sample at most this many examples with a
                fixed seed (0).
            n_threads: worker count passed to ``report.map_with_progress``.
        """
        raw_rows = []
        for url in self._DATA_URLS:
            df = pandas.read_json(url, lines=True)
            raw_rows.extend(row.to_dict() for _, row in df.iterrows())
        examples = [
            {
                "question": row["question"],
                # String answers are reduced to their leading digits (None if none).
                "answer": normalize_number(row["answer"]) if isinstance(row["answer"], str) else row["answer"],
            }
            for row in raw_rows
        ]
        rng = random.Random(0)
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples = None"
            # Cap at the population size: random.sample raises ValueError otherwise.
            examples = rng.sample(examples, min(num_examples, len(examples)))
        # No per-example "permutation" is attached: AIME has no answer options to
        # shuffle and fn below never reads one (it was copied from GPQA).
        self.examples = examples * n_repeats
        self.n_repeats = n_repeats
        self.n_threads = n_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Query ``sampler`` on every example and aggregate the 0/1 scores."""
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(
                    content=format_aime_question(row), role="user"
                )
            ]
            sampler_response = sampler(prompt_messages)
            response_text = sampler_response.response_text
            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
            extracted_answer = extract_boxed_text(response_text)
            correct_answer = int(row["answer"])
            try:  # All AIME answers are integers, so we convert the extracted answer to an integer
                extracted_answer = int(extracted_answer)
            except (ValueError, TypeError):
                extracted_answer = None
            score = 1.0 if extracted_answer == correct_answer else 0.0
            html = report.jinja_env.from_string(report.HTML_JINJA).render(
                prompt_messages=actual_queried_prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
            )

        results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
        return report.aggregate_results(results)
================================================
FILE: gpt_oss/evals/basic_eval.py
================================================
"""
Basic eval
"""
from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
class BasicEval(Eval):
    """Smoke-test eval: a single "hi" prompt, scored 1.0 for any non-empty reply."""

    def __init__(self):
        self.examples = [
            {"question": "hi", "answer": "hi, how can i help?"},
        ]

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Send each example's question to ``sampler`` (one thread) and aggregate."""
        def fn(row: dict):
            user_turn = sampler._pack_message(content=row["question"], role="user")
            reply = sampler([user_turn])
            reply_text = reply.response_text
            queried_messages = reply.actual_queried_message_list
            # Any non-empty answer counts as a pass.
            score = 0.0 if len(reply_text) == 0 else 1.0
            html = report.jinja_env.from_string(report.HTML_JINJA).render(
                prompt_messages=queried_messages,
                next_message=dict(content=reply_text, role="assistant"),
                score=score,
                correct_answer=row["answer"],
                extracted_answer=reply_text,
            )
            return SingleEvalResult(
                html=html,
                score=score,
                convo=queried_messages + [dict(content=reply_text, role="assistant")],
                metrics={"chars": len(reply_text)},
            )

        return report.aggregate_results(
            report.map_with_progress(fn, self.examples, num_threads=1)
        )
================================================
FILE: gpt_oss/evals/chat_completions_sampler.py
================================================
import time
from typing import Any
import openai
from openai import OpenAI
from .types import MessageList, SamplerBase, SamplerResponse
# System prompts that can be passed as ChatCompletionsSampler(system_message=...).
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)
class ChatCompletionsSampler(SamplerBase):
    """Sample from a Chat Completions compatible API."""

    # Upper bound (seconds) on the exponential back-off between retries.
    # Without a cap the sleep was 2**trial seconds and grew without limit.
    max_backoff_seconds: float = 300.0

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        reasoning_model: bool = False,
        reasoning_effort: str | None = None,
        base_url: str = "http://localhost:8000/v1",
    ):
        """
        Args:
            model: model name sent with every request.
            system_message: if set, prepended as a "system" message to each call.
            temperature: sampling temperature.
            max_tokens: completion token limit per request.
            reasoning_model: when True, ``reasoning_effort`` is sent as well.
            reasoning_effort: value of the ``reasoning_effort`` request field.
            base_url: API base URL.
        """
        # Very generous (24 h) per-request timeout.
        self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.reasoning_model = reasoning_model
        self.reasoning_effort = reasoning_effort
        self.image_format = "url"

    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        """Query the model, retrying with capped exponential back-off.

        A ``BadRequestError`` is not retried: a placeholder response is returned
        instead. Any other exception, including an empty completion, is retried
        indefinitely. If the reply carries a ``reasoning`` trace, it is appended
        as an assistant message to the returned ``actual_queried_message_list``
        only. Neither the caller's list nor the messages sent on a retry are
        modified.
        """
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        request_kwargs = dict(
            model=self.model,
            messages=message_list,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        if self.reasoning_model:
            request_kwargs["reasoning_effort"] = self.reasoning_effort

        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(**request_kwargs)
                choice = response.choices[0]
                content = choice.message.content
                if not content:
                    raise ValueError("OpenAI API returned empty response; retrying")
                # Copy before recording the reasoning trace. The old code appended
                # it to the request list itself, so an empty-content retry re-sent
                # the trace as an extra assistant turn.
                queried_messages = list(message_list)
                reasoning = getattr(choice.message, "reasoning", None)
                if reasoning:
                    queried_messages.append(self._pack_message("assistant", reasoning))
                return SamplerResponse(
                    response_text=content,
                    response_metadata={"usage": response.usage},
                    actual_queried_message_list=queried_messages,
                )
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return SamplerResponse(
                    response_text="No response (bad request).",
                    response_metadata={"usage": None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = min(2 ** trial, self.max_backoff_seconds)  # capped exponential back off
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
================================================
FILE: gpt_oss/evals/gpqa_eval.py
================================================
"""
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
https://arxiv.org/abs/2311.12022
"""
import random
import pandas
from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from .abcd_grader import extract_abcd
# Multiple-choice prompt; formatted with keys Question, A, B, C, D.
QUERY_TEMPLATE_MULTICHOICE = "\n".join(
    [
        "{Question}",
        "(A) {A}",
        "(B) {B}",
        "(C) {C}",
        "(D) {D}",
        "Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.",
    ]
)


def format_multichoice_question(row):
    """Fill the multiple-choice template from ``row`` (extra keys are ignored)."""
    template = QUERY_TEMPLATE_MULTICHOICE
    return template.format(**row)
class GPQAEval(Eval):
    """GPQA multiple choice: answer options are shuffled per example, and the
    extracted letter must match the position of the correct answer."""

    def __init__(
        self,
        n_repeats: int = 8,
        variant: str = "diamond",
        num_examples: int | None = None,  # restrict to a subset of the data for debugging
        debug: bool = False,
        n_threads: int = 1,
    ):
        """Load the GPQA ``variant`` CSV and build the (shuffled, repeated) examples.

        Args:
            n_repeats: copies of each example; must be 1 when ``num_examples`` is set.
            variant: GPQA split name inserted into the CSV URL.
            num_examples: if truthy, sample at most this many examples (seed 0).
            debug: keep only rows whose question contains "ESPRESSO spectrograph, please".
            n_threads: worker count passed to ``report.map_with_progress``.
        """
        df = pandas.read_csv(
            f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv"
        )
        rng = random.Random(0)
        if debug:
            examples = [row.to_dict() for _, row in df.iterrows() if "ESPRESSO spectrograph, please" in row["Question"]]
        else:
            examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples = None"
            # Debug mode can leave very few rows, while callers ask for e.g. 5.
            # random.sample raises ValueError if asked for more than the population.
            examples = rng.sample(examples, min(num_examples, len(examples)))
        examples = examples * n_repeats
        # Each copy gets its own random ordering of the four answer options.
        examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
        self.examples = examples
        self.n_repeats = n_repeats
        self.n_threads = n_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Query ``sampler`` on every example and aggregate the 0/1 scores."""
        def fn(row: dict):
            choices = [
                row["Correct Answer"],
                row["Incorrect Answer 1"],
                row["Incorrect Answer 2"],
                row["Incorrect Answer 3"],
            ]
            choices = [choices[i] for i in row["permutation"]]
            correct_index = choices.index(row["Correct Answer"])
            correct_answer = "ABCD"[correct_index]
            choices_dict = dict(
                A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"]
            )
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict), role="user"
                )
            ]
            sampler_response = sampler(prompt_messages)
            response_text = sampler_response.response_text
            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
            extracted_answer = extract_abcd(response_text)
            score = 1.0 if extracted_answer == correct_answer else 0.0
            html = report.jinja_env.from_string(report.HTML_JINJA).render(
                prompt_messages=actual_queried_prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
            )

        results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
        return report.aggregate_results(results)
if __name__ == "__main__":
    # Re-grade a saved *_allresults.json file: compare extract_abcd's answer for
    # the last message of each convo with the correct answer in its HTML report.
    import json
    import re
    import sys

    with open(sys.argv[1], "r") as f:
        results = json.load(f)

    passes = 0
    for convo, html in zip(results["convos"], results["htmls"]):
        message = convo[-1]["content"]
        # The report shows the ground truth as "Correct Answer: <letter>".
        # NOTE(review): the surrounding markup of the original pattern was lost
        # in extraction, leaving a split string literal. Take the LAST match,
        # assuming the results section is rendered after the conversation, so a
        # model reply quoting "Correct Answer: X" does not win. TODO confirm
        # against report.HTML_JINJA.
        ground_truths = re.findall(r"Correct Answer: ([ABCD])\b", html)
        if not ground_truths:
            print("no ground truth found in html; counting as a miss")
            continue
        ground_truth = ground_truths[-1]
        extracted_answer = extract_abcd(message)
        if extracted_answer == ground_truth:
            passes += 1
        elif len(message) > 15:
            print("no match:", message)
            print("ground truth:", ground_truth)
            print("extracted answer:", extracted_answer)
            print("--------------------------------")

    total = len(results["convos"])
    pass_rate = passes / total if total else 0.0
    print(f"pass@1: {pass_rate}")
================================================
FILE: gpt_oss/evals/healthbench_eval.py
================================================
"""
This script evaluates the performance of a model on the HealthBench dataset.
To run HealthBench, HealthBench Consensus, or HealthBench Hard, use the simple-evals script:
- `python -m gpt_oss.evals --eval=healthbench --model=gpt-oss-120b`
- `python -m gpt_oss.evals --eval=healthbench_consensus --model=gpt-oss-120b`
- `python -m gpt_oss.evals --eval=healthbench_hard --model=gpt-oss-120b`
You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run the following command:
- To evaluate physician ideal completions: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completions`
- To evaluate reference model completions used by physicians: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completion_references`
"""
import argparse
import copy
import hashlib
import json
import random
import re
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Literal
import blobfile as bf
import numpy as np
from . import report
from .chat_completions_sampler import (
OPENAI_SYSTEM_MESSAGE_API,
ChatCompletionsSampler,
)
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
# HealthBench data files (JSONL, one example per line). HealthBenchEval.__init__
# picks one by subset_name: None -> INPUT_PATH, "hard" -> INPUT_PATH_HARD,
# "consensus" -> INPUT_PATH_CONSENSUS.
INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl"
INPUT_PATH_HARD = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl"
INPUT_PATH_CONSENSUS = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl"
# Prompt sent to the grader model for each (conversation, rubric item) pair. It
# asks for a JSON object with "explanation" and a boolean "criteria_met".
# NOTE(review): both placeholders below are the literal "<>". grade_sample fills
# the template with .replace("<>", convo_str).replace("<>", str(rubric_item)).
# str.replace substitutes EVERY occurrence, so the first call writes the
# conversation into both slots and the rubric item is never inserted. These were
# probably two distinct tokens (e.g. "<<conversation>>" / "<<rubric_item>>")
# lost when this file was extracted. TODO: restore them here and in grade_sample
# together.
GRADER_TEMPLATE = """
Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.
# Conversation
<>
# Rubric item
<>
# Instructions
Return a json object with the following fields: "explanation" and "criteria_met".
- The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met.
- One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.
# Example 1
For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this:
```json
{
"explanation": "The assistant did not tell the user to go to the hospital if unconscious.",
"criteria_met": false
}
```
# Example 2
Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.
For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this:
```json
{
"explanation": "The assistant actually gave the user accurate information, so the criteria was not met.",
"criteria_met": false
}
```
As another example, if the criteria says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:
```json
{
"explanation": "The response is concise, so the criteria was not met.",
"criteria_met": false
}
```
In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.
# Final instruction
Return just the json object in markdown format. Do not include any other text in the response.
""".strip()
# HealthBench report template: the shared report.HTML_JINJA with its
# "Correct Answer" line removed and a "Rubrics with grades" line appended.
# NOTE(review): the HTML markup around these lines did not survive extraction
# of this file; the original string literals were split across lines, a syntax
# error. The removal below matches the line by its text, so it does not depend
# on the surrounding tags. The <p> wrapper on the appended line is an assumption
# -- TODO confirm against upstream.
HEALTHBENCH_HTML_JINJA = (
    "\n".join(
        template_line
        for template_line in report.HTML_JINJA.split("\n")
        if "Correct Answer: {{ correct_answer }}" not in template_line
    )
    + "<p>Rubrics with grades: {{ rubric_grades }}</p>"
)
def parse_json_to_dict(json_string: str) -> dict:
    """Parse a grader reply into a dict, tolerating a Markdown code fence.

    Strips a leading ```` ```json ```` or bare ```` ``` ```` fence and a trailing
    ```` ``` ```` before decoding. Returns ``{}`` (after printing a message) when
    the text is not valid JSON or decodes to something other than a JSON object.
    The caller then retries the grading request.
    """
    # Remove markdown-style ```json``` (or untagged ```) markers if present.
    # Untagged fences previously survived, and the decode failed on every retry.
    json_cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", json_string.strip())
    try:
        parsed = json.loads(json_cleaned)
    except json.JSONDecodeError as e:
        print(f"JSON decoding failed: {e}")
        return {}
    if not isinstance(parsed, dict):
        # e.g. a bare `true`: a later `"criteria_met" in parsed` would raise TypeError.
        print(f"JSON decoding failed: expected an object, got {type(parsed).__name__}")
        return {}
    return parsed
class RubricItem:
    """A single HealthBench rubric criterion with its point value and tags."""

    def __init__(self, criterion: str, points: float, tags: list[str]):
        self.criterion, self.points, self.tags = criterion, points, tags

    def __str__(self):
        return "[{}] {}".format(self.points, self.criterion)

    def to_dict(self):
        """Serialize to a plain dict with keys criterion, points, tags."""
        return dict(criterion=self.criterion, points=self.points, tags=self.tags)

    @classmethod
    def from_dict(cls, d: dict):
        """Build an item from a dict with keys criterion, points, tags (KeyError if one is missing)."""
        fields = {key: d[key] for key in ("criterion", "points", "tags")}
        return cls(**fields)
def calculate_score(
    rubric_items: list[RubricItem], grading_response_list: list[dict]
) -> float | None:
    """Fraction of the available positive points that were earned.

    Items whose grading response has a truthy "criteria_met" contribute their
    points (negative ones subtract). The total is divided by the sum of the
    positive point values. Returns None when there are no positive points.
    Raises ValueError (from ``zip(strict=True)``) if the two lists differ in length.
    """
    possible = sum(item.points for item in rubric_items if item.points > 0)
    if possible == 0:
        # should not happen for overall score, but may happen for tags
        return None
    earned = 0
    for item, grade in zip(rubric_items, grading_response_list, strict=True):
        if grade["criteria_met"]:
            earned += item.points
    return earned / possible
def get_usage_dict(response_usage) -> dict[str, int | None]:
if response_usage is None:
return {
"input_tokens": None,
"input_cached_tokens": None,
"output_tokens": None,
"output_reasoning_tokens": None,
"total_tokens": None,
}
return {
"input_tokens": response_usage.input_tokens,
"output_tokens": response_usage.output_tokens,
"total_tokens": response_usage.total_tokens,
"input_cached_tokens": None,
"output_reasoning_tokens": None,
}
# Physician-completion groups, keyed by the value that HealthBenchEval compares
# against example["ideal_completions_data"]["ideal_completions_group"].
PHYSICIAN_COMPLETION_MODES = {
    "Group 1": dict(
        description="No reference completions were provided to the physicians.",
        short_name="no_reference",
        has_reference=False,
    ),
    "Group 2": dict(
        description="Reference completions were provided to the physicians from Aug / Sep 2024 models (gpt-4o-2024-08-06, o1-preview).",
        short_name="aug_2024_reference",
        has_reference=True,
    ),
    "Group 3": dict(
        description="Reference completions were provided to the physicians from Apr 2025 models (o3, gpt-4.1).",
        short_name="apr_2025_reference",
        has_reference=True,
    ),
}
def _compute_clipped_stats(
values: list,
stat: str,
):
"""Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and n_samples for final HealthBench scoring."""
if stat == "mean":
return np.clip(np.mean(values), 0, 1)
elif stat == "n_samples":
return len(values)
elif stat == "bootstrap_std":
bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
bootstrap_means = [
_compute_clipped_stats(list(s), "mean") for s in bootstrap_samples
]
return np.std(bootstrap_means)
else:
raise ValueError(f"Unknown {stat =}")
def _aggregate_get_clipped_mean(
    single_eval_results: list[SingleEvalResult],
) -> EvalResult:
    """
    Collapse per-example HealthBench results into one EvalResult.

    Every metric name, plus "score" for examples whose score is not None, is
    summarized with _compute_clipped_stats: the clipped mean under the bare name,
    and "<name>:n_samples" / "<name>:bootstrap_std" next to it. The "score" mean
    becomes the EvalResult score. htmls, convos and per-example metadata are
    collected in input order.
    """
    values_by_metric = defaultdict(list)
    htmls, convos, metadata = [], [], []
    for example_result in single_eval_results:
        for metric_name, metric_value in example_result.metrics.items():
            values_by_metric[metric_name].append(metric_value)
        if example_result.score is not None:
            values_by_metric["score"].append(example_result.score)
        htmls.append(example_result.html)
        convos.append(example_result.convo)
        metadata.append(example_result.example_level_metadata)

    final_metrics = {}
    for metric_name, metric_values in values_by_metric.items():
        final_metrics[metric_name] = _compute_clipped_stats(metric_values, "mean")
        final_metrics[f"{metric_name}:n_samples"] = _compute_clipped_stats(metric_values, "n_samples")
        final_metrics[f"{metric_name}:bootstrap_std"] = _compute_clipped_stats(metric_values, "bootstrap_std")

    overall_score = final_metrics.pop("score", None)
    return EvalResult(
        score=overall_score,
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={"example_level_metadata": metadata},
    )
class HealthBenchEval(Eval):
def __init__(
    self,
    grader_model: SamplerBase,
    num_examples: int | None = None,
    n_repeats: int = 1,
    # If set, evaluate human completions or reference completions instead of model completions.
    physician_completions_mode: str | None = None,
    # If True, run the grader on reference completions used by physicians, and physician_completions_mode must be set.
    run_reference_completions: bool = False,
    n_threads: int = 120,
    subset_name: Literal["hard", "consensus"] | None = None,
):
    """Load a HealthBench data file and prepare the examples to grade.

    Args:
        grader_model: sampler used to grade rubric items (stored as self.grader_model).
        num_examples: if set and smaller than the example count, sample this
            many examples with a fixed seed (0).
        n_repeats: number of times the (sampled) example list is repeated.
        physician_completions_mode: a key of PHYSICIAN_COMPLETION_MODES. When
            set, only examples whose ideal_completions_group equals it are kept,
            and each gets a "completion_to_trial".
        run_reference_completions: use the reference completions shown to the
            physicians (one example per reference, four expected) instead of
            the ideal completion. Requires a mode whose "has_reference" is True.
        n_threads: stored as self.n_threads.
        subset_name: None for the full set, or "hard" / "consensus".

    Raises:
        AssertionError: on inconsistent physician-completion options.
        ValueError: on an unknown subset_name, or when no example matches the
            requested physician_completions_mode.
    """
    if run_reference_completions:
        assert physician_completions_mode is not None, (
            "physician_completions_mode must be provided if run_reference_completions is True"
        )
        # Validate the key first so an unknown mode reports a clear message
        # instead of a bare KeyError on the lookup below.
        assert physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (
            f"Invalid physician completions mode: {physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
        )
        assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][
            "has_reference"
        ], (
            "physician_completions_mode must have reference completions if run_reference_completions is True"
        )

    subset_paths = {
        None: INPUT_PATH,
        "hard": INPUT_PATH_HARD,
        "consensus": INPUT_PATH_CONSENSUS,
    }
    if subset_name not in subset_paths:
        # An explicit raise, unlike `assert False`, still fires under `python -O`.
        raise ValueError(f"Invalid subset name: {subset_name}")
    input_path = subset_paths[subset_name]

    with bf.BlobFile(input_path, "rb") as f:
        examples = [json.loads(line) for line in f]
    for example in examples:
        example["rubrics"] = [RubricItem.from_dict(d) for d in example["rubrics"]]

    rng = random.Random(0)

    # physician completions mode
    self.physician_completions_mode = physician_completions_mode
    if self.physician_completions_mode is not None:
        assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (
            f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
        )
        # subset to only the rows which have physician completions from that group
        examples_matching_mode = [
            example
            for example in examples
            if example["ideal_completions_data"] is not None
            and example["ideal_completions_data"]["ideal_completions_group"]
            == self.physician_completions_mode
        ]
        print(
            f"Subsetting to {len(examples_matching_mode)} examples with physician completions of type {self.physician_completions_mode} ({PHYSICIAN_COMPLETION_MODES[self.physician_completions_mode]['description']})"
        )

        examples = []
        if run_reference_completions:
            # One example per reference completion.
            for example in examples_matching_mode:
                for completion in example["ideal_completions_data"][
                    "ideal_completions_ref_completions"
                ]:
                    new_example = copy.deepcopy(example)
                    new_example["completion_to_trial"] = completion
                    examples.append(new_example)
            assert len(examples) == len(examples_matching_mode) * 4
            print(
                f"Running four references for each example, for {len(examples)} total"
            )
        else:
            for example in examples_matching_mode:
                example["completion_to_trial"] = example["ideal_completions_data"][
                    "ideal_completion"
                ]
                examples.append(example)
            assert len(examples) == len(examples_matching_mode)

        if len(examples) == 0:
            raise ValueError(
                f"No examples found matching mode {self.physician_completions_mode}"
            )

    if num_examples is not None and num_examples < len(examples):
        examples = rng.sample(
            examples,
            num_examples,
        )

    self.examples = examples * n_repeats
    self.n_threads = n_threads
    self.grader_model = grader_model
def grade_sample(
    self,
    prompt: list[dict[str, str]],
    response_text: str,
    example_tags: list[str],
    rubric_items: list[RubricItem],
) -> tuple[dict, str, list[dict]]:
    """Grade ``response_text`` against every rubric item using ``self.grader_model``.

    Each rubric item is graded by a separate grader call, run through
    ``report.map_with_progress``. The grader is re-queried until it returns
    JSON whose ``criteria_met`` is a boolean.

    Args:
        prompt: conversation that was sent to the model under test.
        response_text: assistant response being graded.
        example_tags: example-level tags; each one receives the overall score.
        rubric_items: rubric criteria to grade the response against.

    Returns:
        ``(metrics, readable_explanation_str, rubric_items_with_grades)``:
        ``metrics`` holds ``overall_score`` plus per-example-tag and
        per-rubric-tag scores; the string lists every graded criterion (unmet
        ones first); the list holds each rubric item's dict with its grade.
    """
    # construct the conversation string once; it is identical for every rubric item
    convo_with_response = prompt + [dict(content=response_text, role="assistant")]
    convo_str = "\n\n".join(
        [f"{m['role']}: {m['content']}" for m in convo_with_response]
    )

    def grade_rubric_item(
        indexed_rubric_item: tuple[int, RubricItem],
    ) -> tuple[int, dict]:
        # The index travels with the grade because map_with_progress may
        # return results in completion order rather than input order.
        index, rubric_item = indexed_rubric_item
        grader_prompt = GRADER_TEMPLATE.replace(
            "<>", convo_str
        ).replace("<>", str(rubric_item))
        messages: MessageList = [dict(content=grader_prompt, role="user")]
        # Re-query until the grader produces a boolean criteria_met.
        # NOTE(review): there is no retry limit; a grader that never complies
        # blocks this call indefinitely.
        while True:
            sampler_response = self.grader_model(messages)
            grading_response = sampler_response.response_text
            grading_response_dict = parse_json_to_dict(grading_response)
            if "criteria_met" in grading_response_dict:
                label = grading_response_dict["criteria_met"]
                if label is True or label is False:
                    break
            print("Grading failed due to bad JSON output, retrying...")
        return index, grading_response_dict

    indexed_grades = report.map_with_progress(
        grade_rubric_item,
        list(enumerate(rubric_items)),
        pbar=False,
    )
    # Restore rubric order so every zip(rubric_items, grading_response_list)
    # below pairs each rubric item with its own grade.
    grading_response_list = [
        grading_response
        for _, grading_response in sorted(indexed_grades, key=lambda pair: pair[0])
    ]

    # compute the overall score
    overall_score = calculate_score(rubric_items, grading_response_list)
    assert overall_score is not None
    metrics = {
        "overall_score": overall_score,
    }

    # compute scores for example-level tags (each tag gets the overall score)
    example_tag_scores = {tag: overall_score for tag in example_tags}
    assert len(example_tag_scores) == len(example_tags)  # No duplicates.
    metrics.update(example_tag_scores)

    # compute scores for rubric-level tags
    rubric_tag_items_grades = defaultdict(list)
    for rubric_item, grading_response in zip(rubric_items, grading_response_list):
        curr_item_tags = set()  # Ensure no duplicates in a rubric item.
        for tag in rubric_item.tags:
            rubric_tag_items_grades[tag].append((rubric_item, grading_response))
            assert tag not in curr_item_tags
            curr_item_tags.add(tag)

    rubric_tag_scores = {}
    for tag, items_grades in rubric_tag_items_grades.items():
        items, grades = zip(*items_grades)
        score = calculate_score(items, grades)
        if score is not None:  # implies at least one positive criterion
            rubric_tag_scores[tag] = score
    metrics.update(rubric_tag_scores)

    # construct the list of explanations and grades
    rubric_items_with_grades = []
    readable_explanation_list = []
    for rubric_item, grading_response in zip(rubric_items, grading_response_list):
        explanation = grading_response.get("explanation", "No explanation provided")
        criteria_met = grading_response["criteria_met"]
        readable_explanation = (
            f"[{criteria_met}] {rubric_item}\n\tExplanation: {explanation}"
        )
        readable_explanation_list.append(readable_explanation)
        rubric_items_with_grades.append(
            {
                **rubric_item.to_dict(),
                "criteria_met": criteria_met,
                "explanation": explanation,
            }
        )

    # unmet criteria ("[False]" entries) sort first; sort is stable otherwise
    readable_explanation_list.sort(
        key=lambda x: x.startswith("[False]"), reverse=True
    )
    readable_explanation_str = "\n\n".join(readable_explanation_list)
    readable_explanation_str = f"\n\n{readable_explanation_str}"
    return metrics, readable_explanation_str, rubric_items_with_grades
def __call__(self, sampler: SamplerBase) -> EvalResult:
# Run the eval over self.examples: obtain a response per example (from the
# sampler, or the stored physician completion), grade it against its rubric,
# and aggregate all per-example results with _aggregate_get_clipped_mean.
def fn(row: dict):
# Evaluate one example row and return its SingleEvalResult.
prompt_messages = row["prompt"]
# In physician-completion mode the stored completion is graded directly and
# `sampler` is never called.
if self.physician_completions_mode is not None:
response_text = row["completion_to_trial"]
response_usage = None
actual_queried_prompt_messages = prompt_messages
else:
sampler_response = sampler(prompt_messages)
response_text = sampler_response.response_text
response_dict = sampler_response.response_metadata
# Grade against the messages the sampler actually sent, which may differ
# from the row's prompt (e.g. a prepended developer/system message).
actual_queried_prompt_messages = (
sampler_response.actual_queried_message_list
)
response_usage = response_dict.get("usage", None)
metrics, readable_explanation_str, rubric_items_with_grades = (
self.grade_sample(
prompt=actual_queried_prompt_messages,
response_text=response_text,
rubric_items=row["rubrics"],
example_tags=row["example_tags"],
)
)
score = metrics["overall_score"]
# Create HTML for each sample result
# NOTE(review): the replacement string passed to .replace("\n", ...) below
# spans a line break in this copy (markup appears stripped) — verify upstream.
html = report.jinja_env.from_string(
HEALTHBENCH_HTML_JINJA.replace(
"{{ rubric_grades }}",
readable_explanation_str.replace("\n", "
"),
)
).render(
prompt_messages=actual_queried_prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=metrics["overall_score"],
extracted_answer=response_text,
)
convo = actual_queried_prompt_messages + [
dict(content=response_text, role="assistant")
]
# completion_id is a SHA-256 of prompt_id + response text, so identical
# responses to the same prompt share an id.
return SingleEvalResult(
html=html,
score=score,
convo=convo,
metrics=metrics,
example_level_metadata={
"score": score,
"usage": get_usage_dict(response_usage),
"rubric_items": rubric_items_with_grades,
"prompt": actual_queried_prompt_messages,
"completion": [dict(content=response_text, role="assistant")],
"prompt_id": row["prompt_id"],
"completion_id": hashlib.sha256(
(row["prompt_id"] + response_text).encode("utf-8")
).hexdigest(),
},
)
# Evaluate all examples in parallel with self.n_threads worker threads.
results = report.map_with_progress(
fn,
self.examples,
num_threads=self.n_threads,
pbar=True,
)
final_metrics = _aggregate_get_clipped_mean(results)
return final_metrics
def main():
    """Command-line entry point for grading HealthBench physician completions.

    ``--run_mode physician_completions`` grades each example's ideal
    (physician-written) completion; ``--run_mode physician_completion_references``
    grades the reference completions instead. ``--examples`` optionally caps
    the number of examples and ``--n-threads`` sets the grading parallelism.
    """
    parser = argparse.ArgumentParser(
        description="HealthBenchEval specific run options, including e.g., running the eval on physician completions rows only."
    )
    parser.add_argument(
        "--run_mode",
        type=str,
        choices=["physician_completions", "physician_completion_references"],
        # Required: without a mode there is nothing to run, and argparse then
        # reports a clear usage error instead of "Invalid run mode: None".
        required=True,
        help="Which physician completions to grade",
    )
    parser.add_argument("--examples", type=int, help="Number of examples to run")
    parser.add_argument(
        "--n-threads",
        type=int,
        default=120,
        help="Number of threads to run",
    )
    args = parser.parse_args()
    physician_completions_main(
        run_reference_completions=args.run_mode == "physician_completion_references",
        num_examples=args.examples,
        # --n-threads 0 would otherwise request a thread pool with no workers
        n_threads=args.n_threads or 1,
    )
def physician_completions_main(
    run_reference_completions: bool = False,
    num_examples: int | None = None,
    n_threads: int = 120,
):
    """Grade physician completions for every mode in ``PHYSICIAN_COMPLETION_MODES``.

    For each mode (skipping modes without references when
    ``run_reference_completions`` is set), runs ``HealthBenchEval`` with a
    gpt-4.1 grader and writes three files under ``/tmp``: an HTML report, a
    metrics JSON, and a JSON containing all results.

    Args:
        run_reference_completions: grade each example's reference completions
            instead of its ideal completion.
        num_examples: optional cap on the number of examples per mode.
        n_threads: worker threads used to grade examples.

    Returns:
        list of ``{"eval_name", "model_name", "metric"}`` dicts, one per mode run.
    """
    now = datetime.now()
    date_str = now.strftime("%Y%m%d_%H%M")
    grading_sampler = ChatCompletionsSampler(
        model="gpt-4.1-2025-04-14",
        system_message=OPENAI_SYSTEM_MESSAGE_API,
        max_tokens=2048,
        base_url="https://api.openai.com/v1",
    )
    # Physician modes grade stored completions, so this sampler is never called.
    dummy_sampler = SamplerBase()

    merge_metrics = []
    for pc_mode in PHYSICIAN_COMPLETION_MODES.keys():
        if (
            run_reference_completions
            and not PHYSICIAN_COMPLETION_MODES[pc_mode]["has_reference"]
        ):
            continue

        # run (named to avoid shadowing the builtin `eval`)
        healthbench_eval = HealthBenchEval(
            grader_model=grading_sampler,
            physician_completions_mode=pc_mode,
            run_reference_completions=run_reference_completions,
            num_examples=num_examples,
            n_threads=n_threads,
        )
        result = healthbench_eval(dummy_sampler)

        # report
        parsable_mode = PHYSICIAN_COMPLETION_MODES[pc_mode]["short_name"]
        if run_reference_completions:
            file_stem = f"healthbench_{parsable_mode}_referencecompletions_{date_str}"
        else:
            file_stem = f"healthbench_{parsable_mode}_humanbaseline_{date_str}"
        report_filename = Path(f"/tmp/{file_stem}.html")
        # Explicit UTF-8: the report embeds arbitrary response text, which the
        # locale's default encoding may not be able to represent.
        report_filename.write_text(report.make_report(result), encoding="utf-8")
        print(f"Report saved to {report_filename}")

        # metrics
        assert result.metrics is not None
        metrics = result.metrics
        result_filename = Path(f"/tmp/{file_stem}.json")
        result_filename.write_text(json.dumps(metrics), encoding="utf-8")
        print(f"Results saved to {result_filename}")

        full_result_dict = {
            "score": result.score,
            "metrics": result.metrics,
            "htmls": result.htmls,
            "convos": result.convos,
            "metadata": result.metadata,
        }
        full_result_filename = Path(f"/tmp/{file_stem}_allresults.json")
        full_result_filename.write_text(
            json.dumps(full_result_dict, indent=2), encoding="utf-8"
        )
        print(f"All results saved to {full_result_filename}")

        # metrics df
        merge_metrics.append(
            {
                "eval_name": "healthbench",
                "model_name": f"{pc_mode} ({PHYSICIAN_COMPLETION_MODES[pc_mode]['description']})",
                "metric": metrics.get("overall_score", None),
            }
        )

    print("\nAll results: ")
    print(merge_metrics)
    return merge_metrics
# Script entry point: parse CLI arguments and run the physician-completions eval.
if __name__ == "__main__":
main()
================================================
FILE: gpt_oss/evals/report.py
================================================
import os
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Any, Callable
import jinja2
import numpy as np
from tqdm import tqdm
from .types import EvalResult, Message, SingleEvalResult
# Per-sample HTML template: the prompt conversation, the sampled message, and
# the correct/extracted answers with the score. It calls the `message_to_html`
# global registered on jinja_env, and expects prompt_messages, next_message,
# correct_answer, extracted_answer and score as render variables.
HTML_JINJA = """
Prompt conversation
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
Sampled message
{{ message_to_html(next_message) | safe }}
Results
Correct Answer: {{ correct_answer }}
Extracted Answer: {{ extracted_answer }}
Score: {{ score }}
"""
def _compute_stat(values: list, stat: str):
if stat == "mean":
return np.mean(values)
elif stat == "std":
return np.std(values)
elif stat == "min":
return np.min(values)
elif stat == "max":
return np.max(values)
elif stat == "n_samples":
return len(values)
elif stat == "bootstrap_std":
return np.std(
[np.mean(np.random.choice(values, len(values))) for _ in range(1000)]
)
else:
raise ValueError(f"Unknown {stat =}")
def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    name2stats: dict[str, tuple[str, ...]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Every metric name seen across ``single_eval_results`` is reduced with the
    stats in ``name2stats[name]``, falling back to ``default_stats``. Sample
    scores that are not None are collected under the name ``"score"``. The
    ``"mean"`` stat is stored under the bare metric name; every other stat is
    stored as ``"<name>:<stat>"``. The ``"score"`` entry is moved out of the
    metrics into ``EvalResult.score``.

    Args:
        single_eval_results: per-sample results to combine.
        default_stats: stats computed for metrics without an override.
        name2stats: optional per-metric override of the stats to compute
            (any number of stat names per metric).

    Returns:
        EvalResult with the aggregated metrics plus per-sample htmls, convos,
        and example-level metadata, in input order.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        # samples without a score are left out of the score stats
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={"example_level_metadata": metadata},
    )
def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = 128,
    pbar: bool = True,
):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Results are returned in the same order as ``xs``, so callers can safely
    zip them back against their inputs. If the ``debug`` environment variable
    is set, ``f`` runs sequentially in the calling thread instead.

    Args:
        f: function applied to each element.
        xs: inputs; an empty list returns ``[]`` without creating a pool.
        num_threads: upper bound on worker threads (at least one is used).
        pbar: show a tqdm progress bar when True.

    Returns:
        list of ``f(x)`` for each ``x`` in ``xs``, in input order.
    """
    if not xs:
        # ThreadPool(0) raises ValueError, so empty input short-circuits.
        return []
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x
    if os.getenv("debug"):
        return list(map(f, pbar_fn(xs, total=len(xs))))
    num_workers = max(1, min(num_threads, len(xs)))
    with ThreadPool(num_workers) as pool:
        # imap (unlike imap_unordered) yields results in input order.
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))
# Shared Jinja environment for the report templates. StrictUndefined makes a
# missing template variable raise instead of rendering as empty. Autoescaping
# is selected via select_autoescape; per the Jinja docs this also covers
# templates built with from_string, which is why pre-rendered HTML is passed
# through `| safe` in the templates.
jinja_env = jinja2.Environment(
loader=jinja2.BaseLoader(),
undefined=jinja2.StrictUndefined,
autoescape=jinja2.select_autoescape(["html", "xml"]),
)
# Template for a single chat message, rendered by message_to_html.
# NOTE(review): markup appears stripped in this copy, and `content` (which
# message_to_html passes) is not referenced here — verify against upstream.
_message_template = """
{{ role }}
{% if variant %}({{ variant }}){% endif %}
"""
def message_to_html(message: Message) -> str:
    """Render one chat message to an HTML snippet via ``_message_template``.

    Reads ``role`` and ``content`` from ``message`` (both required) and the
    optional ``variant`` (defaults to None).
    """
    template = jinja_env.from_string(_message_template)
    return template.render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )
# Make message_to_html callable from every template rendered via jinja_env.
jinja_env.globals["message_to_html"] = message_to_html
# Standalone report: an optional metrics table (top-line score rounded to 3
# decimals, then each metric), followed by every per-sample HTML block.
_report_template = """
{% if metrics %}
Metrics
| Metric |
Value |
| Score |
{{ score | float | round(3) }} |
{% for name, value in metrics.items() %}
| {{ name }} |
{{ value }} |
{% endfor %}
{% endif %}
Examples
{% for html in htmls %}
{{ html | safe }}
{% endfor %}
"""
def make_report(eval_result: EvalResult) -> str:
    """Render ``eval_result`` into a standalone HTML report string.

    The page is built from ``_report_template`` using the result's top-line
    score, its metrics table, and each per-sample HTML snippet.
    """
    template = jinja_env.from_string(_report_template)
    render_context = dict(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )
    return template.render(**render_context)
================================================
FILE: gpt_oss/evals/responses_sampler.py
================================================
import time
from typing import Any
import openai
from openai import OpenAI
from .types import MessageList, SamplerBase, SamplerResponse
class ResponsesSampler(SamplerBase):
    """
    Sample from OpenAI's responses API
    """

    def __init__(
        self,
        model: str,
        developer_message: str | None = None,
        temperature: float = 1.0,
        max_tokens: int = 131_072,
        reasoning_model: bool = False,
        reasoning_effort: str | None = None,
        base_url: str = "http://localhost:8000/v1",
    ):
        # 24-hour client timeout so long generations are not cut off.
        self.client = OpenAI(base_url=base_url, timeout=24*60*60)
        self.model = model
        self.developer_message = developer_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"
        self.reasoning_model = reasoning_model
        self.reasoning_effort = reasoning_effort

    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
        return {"role": role, "content": content}

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        """Query the Responses API with ``message_list`` and wrap the reply.

        A ``developer`` message is prepended when ``developer_message`` is set.
        Output items exposing ``.text`` are appended to the returned
        ``actual_queried_message_list``; the caller's ``message_list`` is
        never modified.

        ``openai.BadRequestError`` yields an empty response. Every other
        exception is retried indefinitely with exponential backoff.
        """
        if self.developer_message:
            message_list = [
                self._pack_message("developer", self.developer_message)
            ] + message_list
        trial = 0
        while True:
            try:
                request_kwargs = {
                    "model": self.model,
                    "input": message_list,
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens,
                }
                if self.reasoning_model:
                    request_kwargs["reasoning"] = (
                        {"effort": self.reasoning_effort} if self.reasoning_effort else None
                    )
                response = self.client.responses.create(**request_kwargs)
                # Build the returned list separately so neither the caller's
                # list nor the input of a later retry is mutated.
                queried_message_list = list(message_list)
                for output in response.output:
                    if hasattr(output, "text"):
                        queried_message_list.append(
                            self._pack_message(getattr(output, "role", "assistant"), output.text)
                        )
                return SamplerResponse(
                    response_text=response.output_text,
                    response_metadata={"usage": response.usage},
                    actual_queried_message_list=queried_message_list,
                )
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return SamplerResponse(
                    response_text="",
                    response_metadata={"usage": None},
                    actual_queried_message_list=message_list,
                )
            except Exception as e:
                exception_backoff = 2**trial  # exponential back off
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # any exception other than BadRequestError is retried above
================================================
FILE: gpt_oss/evals/types.py
================================================
from dataclasses import dataclass, field
from typing import Any, Literal, overload
# A single chat message as a plain dict; readers in this codebase use the
# "role" and "content" keys, and optionally "variant".
Message = dict[str, Any] # keys role, content
# An ordered conversation: a list of Message dicts.
MessageList = list[Message]
@dataclass
class SamplerResponse:
"""
Response from a sampler.
"""
# Text of the model's reply.
response_text: str
# Messages actually sent to the model; may differ from the caller's input,
# e.g. ResponsesSampler prepends a developer message when configured.
actual_queried_message_list: MessageList
# Sampler-specific extras; e.g. ResponsesSampler stores token usage under "usage".
response_metadata: dict[str, Any]
class SamplerBase:
    """
    Common interface for anything that maps a message list to a
    SamplerResponse: the model being evaluated or a grader model.

    Subclasses override ``__call__``. The base class itself is a plain,
    instantiable class (not an ``abc.ABC``); calling it raises.
    """

    def __call__(self, message_list: MessageList) -> SamplerResponse:
        """Sample a response to ``message_list``; subclasses must override."""
        raise NotImplementedError()
@dataclass
class EvalResult:
"""
Result of running an evaluation (usually consisting of many samples)
"""
score: float | None # top-line metric
metrics: dict[str, float] | None # other metrics
htmls: list[str] # strings of valid HTML
convos: list[MessageList] # sampled conversations
# Extra data such as rubric scores; aggregate_results stores per-sample
# metadata under the "example_level_metadata" key.
metadata: dict[str, Any] | None # Extra data such as rubric scores
@dataclass
class SingleEvalResult:
"""
Result of evaluating a single sample
"""
# Sample score; None leaves the sample out of the aggregated "score" stats.
score: float | None
# Named per-sample metrics, each aggregated across samples.
metrics: dict[str, float] = field(default_factory=dict)
# Rendered HTML for this sample, included in the report.
html: str | None = None
convo: MessageList | None = None # sampled conversation
example_level_metadata: dict[str, Any] | None = (
None # Extra data such as rubric scores
)
class Eval:
    """
    Interface for an evaluation: calling it with the sampler under test
    produces an EvalResult. Subclasses must override ``__call__``.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run the evaluation against ``sampler``; subclasses must override."""
        raise NotImplementedError()
================================================
FILE: gpt_oss/generate.py
================================================
# Model parallel inference
# Note: This script is for demonstration purposes only. It is not designed for production use.
# See gpt_oss.chat for a more complete example with the Harmony parser.
# torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/
import argparse
from gpt_oss.tokenizer import get_tokenizer
def main(args):
    """Generate tokens for ``args.prompt`` and print each one with its logprob.

    ``args.backend`` selects the torch, triton, or vllm token generator.
    ``args.limit == 0`` means no token limit.
    """
    # Backend modules are imported only inside their branch, presumably so
    # that only the chosen backend's dependencies must be installed.
    backend = args.backend
    if backend == "torch":
        from gpt_oss.torch.utils import init_distributed
        from gpt_oss.torch.model import TokenGenerator as TorchGenerator
        device = init_distributed()
        generator = TorchGenerator(args.checkpoint, device=device)
    elif backend == "triton":
        from gpt_oss.torch.utils import init_distributed
        from gpt_oss.triton.model import TokenGenerator as TritonGenerator
        device = init_distributed()
        generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)
    elif backend == "vllm":
        from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
        generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tokenizer = get_tokenizer()
    tokens = tokenizer.encode(args.prompt)
    max_tokens = args.limit if args.limit != 0 else None
    token_stream = generator.generate(
        tokens,
        stop_tokens=[tokenizer.eot_token],
        temperature=args.temperature,
        max_tokens=max_tokens,
        return_logprobs=True,
    )
    for token, logprob in token_stream:
        tokens.append(token)
        token_text = tokenizer.decode([token])
        print(f"Generated token: {repr(token_text)}, logprob: {logprob}")
# CLI entry point (see the torchrun invocation example at the top of the file).
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Text generation example")
parser.add_argument(
"checkpoint",
metavar="FILE",
type=str,
help="Path to the SafeTensors checkpoint",
)
parser.add_argument(
"-p",
"--prompt",
metavar="PROMPT",
type=str,
default="How are you?",
help="LLM prompt",
)
parser.add_argument(
"-t",
"--temperature",
metavar="TEMP",
type=float,
default=0.0,
help="Sampling temperature",
)
# main() maps --limit 0 to max_tokens=None (no limit).
parser.add_argument(
"-l",
"--limit",
metavar="LIMIT",
type=int,
default=0,
help="Limit on the number of tokens (0 to disable)",
)
parser.add_argument(
"-b",
"--backend",
metavar="BACKEND",
type=str,
default="torch",
choices=["triton", "torch", "vllm"],
help="Inference backend",
)
# Backend-specific options: each is read only by the matching branch of main().
parser.add_argument(
"--tensor-parallel-size",
type=int,
default=2,
help="Tensor parallel size for vLLM backend",
)
parser.add_argument(
"--context-length",
type=int,
default=4096,
help="Context length for Triton backend",
)
args = parser.parse_args()
main(args)
================================================
FILE: gpt_oss/metal/CMakeLists.txt
================================================
# Top-level project. Objective-C is enabled for the Metal host code (source/metal.m).
cmake_minimum_required(VERSION 3.24)
project(GPTOSS
VERSION 1.0
DESCRIPTION "Local GPT-OSS inference"
LANGUAGES C CXX OBJC)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_OBJC_STANDARD 11)
set(CMAKE_OBJC_STANDARD_REQUIRED ON)
# Apple frameworks linked into the metal-kernels library.
find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
# Metal shader sources, all linked into a single default.metallib.
set(METAL_SOURCES
  ${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
  ${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
)
set(METAL_LIB default.metallib)

include_directories(BEFORE include source/include)

# Derive one "metal -c" step and one .air object per shader from METAL_SOURCES,
# so the shader list is maintained in a single place.
set(METAL_AIR_FILES "")
set(METAL_COMPILE_COMMANDS "")
foreach(metal_source IN LISTS METAL_SOURCES)
  get_filename_component(metal_name "${metal_source}" NAME_WE)
  set(metal_air "${CMAKE_CURRENT_BINARY_DIR}/source/${metal_name}.air")
  list(APPEND METAL_AIR_FILES "${metal_air}")
  list(APPEND METAL_COMPILE_COMMANDS
    COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${metal_source}" -o "${metal_air}")
endforeach()

add_custom_command(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
  COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
  ${METAL_COMPILE_COMMANDS}
  COMMAND xcrun -sdk macosx metallib ${METAL_AIR_FILES} -o "${METAL_LIB}"
  DEPENDS ${METAL_SOURCES}
  COMMENT "Compiling Metal compute library"
)
add_custom_target(build_metallib ALL
  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
# Shared logging object library.
add_library(log OBJECT source/log.c)
# Host-side Metal kernel wrappers; built only after the metallib exists.
add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
target_link_libraries(metal-kernels PRIVATE log)
add_dependencies(metal-kernels build_metallib)
# Copy the compiled metallib after building metal-kernels.
# NOTE(review): the destination argument appears truncated to `$` in this copy
# (likely a generator expression such as a target directory) — verify upstream.
add_custom_command(TARGET metal-kernels POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
$
)
target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
# Model/tokenizer/context library and the `generate` CLI built on it.
add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
target_link_libraries(gptoss PRIVATE log metal-kernels)
add_executable(generate source/generate.c)
target_link_libraries(generate gptoss)
# --- [ Tests
include(FetchContent)
FetchContent_Declare(
  googletest
  URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
  DOWNLOAD_EXTRACT_TIMESTAMP OFF
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

enable_testing()

# Each kernel test lives in test/<name>.cc, builds into <name>-test against
# GoogleTest and metal-kernels, and is registered with CTest under that name.
foreach(kernel_test IN ITEMS
    u32-random
    f32-random
    mf4-f32-convert
    bf16-f32-embeddings
    f32-bf16w-rmsnorm
    f32-bf16w-matmul
    f32-rope)
  set(test_target ${kernel_test}-test)
  add_executable(${test_target} test/${kernel_test}.cc)
  target_link_libraries(${test_target} PRIVATE GTest::gtest_main metal-kernels)
  target_include_directories(${test_target} PRIVATE source/include)
  add_test(NAME ${test_target} COMMAND ${test_target})
endforeach()
# --- [ Benchmarks
include(FetchContent)
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
FetchContent_Declare(
  benchmark
  URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
  DOWNLOAD_EXTRACT_TIMESTAMP OFF
)
FetchContent_MakeAvailable(benchmark)

# Kernel micro-benchmarks (benchmark/<name>.cc) link only the Metal kernels.
foreach(kernel_bench IN ITEMS f32-random u32-random mf4-f32-convert f32-bf16w-rmsnorm)
  add_executable(${kernel_bench}-bench benchmark/${kernel_bench}.cc)
  target_link_libraries(${kernel_bench}-bench PRIVATE benchmark::benchmark metal-kernels)
  target_include_directories(${kernel_bench}-bench PRIVATE source/include)
endforeach()

# End-to-end benchmarks link the full gptoss library.
foreach(e2e_bench IN ITEMS end-to-end end-to-end-threadgroup)
  add_executable(${e2e_bench}-bench benchmark/${e2e_bench}.cc)
  target_link_libraries(${e2e_bench}-bench PRIVATE benchmark::benchmark gptoss)
  target_include_directories(${e2e_bench}-bench PRIVATE source/include)
endforeach()
# --- [ Python extension ] -----------------------------------------------
find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
pybind11_add_module(_metal
python/module.c
python/context.c
python/model.c
python/tokenizer.c
)
# Empty PREFIX: no "lib" prefix on the module file name.
set_target_properties(_metal PROPERTIES PREFIX "")
target_link_libraries(_metal PRIVATE gptoss)
add_dependencies(_metal build_metallib)
# Embed the metallib into a __METAL,__shaders section of the module binary.
target_link_options(_metal PRIVATE
LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
)
# NOTE(review): the copy destination appears truncated to `$` in this copy
# (likely a generator expression) — verify against upstream.
add_custom_command(TARGET _metal POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
$)
# 1️⃣ install the extension module into the Python package
install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
# 2️⃣ make sure the Metal shader archive travels with it
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
DESTINATION gpt_oss/metal)
# ------------------------------------------------------------------------
================================================
FILE: gpt_oss/metal/__init__.py
================================================
from importlib import import_module as _im
# Load the compiled extension (gpt_oss.metal._metal) and re-export every
# public (non-underscore) name it defines at the package level.
_ext = _im(f"{__name__}._metal")
globals().update(
    (name, value) for name, value in vars(_ext).items() if not name.startswith("_")
)
del _im, _ext
================================================
FILE: gpt_oss/metal/benchmark/end-to-end-threadgroup.cc
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Tokens sampled per benchmark iteration; also scales the "tokens" rate counter.
constexpr std::uint32_t kNumGeneratedTokens = 100;
// Decode-throughput benchmark with model->attn_qkv_threadgroup_size set to
// state.range(0). Loads the model from the path in env_var_name (skipped if
// unset), prefills a fixed prompt once, then each iteration rewinds the context
// counters and samples kNumGeneratedTokens tokens.
static void attn_qkv_tgsize(benchmark::State& state, const char* env_var_name) {
const char* model_path = getenv(env_var_name);
if (model_path == NULL) {
state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
return;
}
gptoss_model_t model_ptr = nullptr;
gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
if (status != gptoss_status_success) {
state.SkipWithError(std::format("failed to load model from file {}", model_path));
return;
}
// Owned by a std::unique_ptr with gptoss_model_release as the deleter.
// NOTE(review): template arguments appear stripped in this copy — verify upstream.
std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
model->attn_qkv_threadgroup_size = static_cast(state.range(0));
gptoss_context_t context_ptr = nullptr;
// NOTE(review): 0 for context_length/max_batch_tokens presumably selects defaults — confirm.
status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
if (status != gptoss_status_success) {
state.SkipWithError("failed to create Context object");
return;
}
std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
const char* prompt = "why did the chicken cross the road?";
std::size_t num_prompt_tokens = 0;
status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
if (status != gptoss_status_success) {
state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
return;
}
// Prefill
status = gptoss_context_process(context.get());
if (status != gptoss_status_success) {
state.SkipWithError("failed to prefill Context object");
return;
}
// NOTE(review): num_kvcache_tokens is never read; the loop rewinds to
// num_prompt_tokens instead — confirm which is intended.
const std::size_t num_kvcache_tokens = context->num_kv_tokens;
std::uint64_t rng_seed = 0;
for (auto _ : state) {
// A fresh seed per iteration; the same seed is passed to every sample call
// within that iteration.
const std::uint64_t current_rng_seed = rng_seed++;
// Rewind token and KV counters so each iteration decodes from the prompt.
context->num_kv_tokens = num_prompt_tokens;
context->num_tokens = num_prompt_tokens;
std::array tokens;
std::size_t num_generated_tokens = 0;
// Keep sampling until kNumGeneratedTokens tokens have been produced in total.
// NOTE(review): loops forever if a call succeeds but yields 0 tokens.
do {
std::size_t num_current_generated_tokens = 0;
status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
/*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
if (status != gptoss_status_success) {
state.SkipWithError("failed to sample from the Context object");
return;
}
num_generated_tokens += num_current_generated_tokens;
} while (num_generated_tokens < kNumGeneratedTokens);
}
// Report generations/sec and tokens/sec.
state.counters["generations"] =
benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
state.counters["tokens"] =
benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
static void AttnQKVThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
  b->ArgNames({"tgsize"});
  // Sweep threadgroup sizes 32, 64, ..., 1024, expressed as a count of
  // 32-thread SIMD groups. Only counts that divide 5120 evenly are kept;
  // the others are incompatible (5120 presumably the QKV width — confirm).
  for (int num_simdgroups = 1; num_simdgroups <= 1024 / 32; ++num_simdgroups) {
    if (5120 % num_simdgroups == 0) {
      b->Args({num_simdgroups * 32});
    }
  }
}
// Register the QKV threadgroup-size sweep for both model sizes; each run reads
// its model path from the named environment variable and is skipped if unset.
BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);
BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);
// Decode benchmark that sweeps model->attn_out_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void attn_out_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->attn_out_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for attn_out_tgsize. Registers, under the name "tgsize", every
// threadgroup size that is a whole number of 32-thread simdgroups, from 32 up to
// 1024 threads, whose simdgroup count divides 2880 evenly.
// NOTE(review): 2880 presumably is the attention-output width split across
// simdgroups — confirm against the kernel.
static void AttnOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    constexpr int kMaxThreadgroupSize = 1024;
    for (int simdgroups = 1; simdgroups * kThreadsPerSimdgroup <= kMaxThreadgroupSize; ++simdgroups) {
        const bool divides_evenly = (2880 % simdgroups) == 0;
        if (divides_evenly) {
            b->Args({simdgroups * kThreadsPerSimdgroup});
        }
    }
}

// Sweep the attention output threadgroup size on both model sizes.
BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(AttnOutThreadgroupSizeArguments);
BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(AttnOutThreadgroupSizeArguments);
// Decode benchmark that sweeps model->mlp_gate_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void mlp_gate_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->mlp_gate_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for mlp_gate_tgsize. Registers, under the name "tgsize", every
// threadgroup size that is a whole number of 32-thread simdgroups, from 32 up to
// 1024 threads, whose simdgroup count divides 128 evenly.
// NOTE(review): the meaning of 128 (presumably related to the expert count the gate
// kernel splits across simdgroups) is not stated here — confirm against the kernel.
static void MlpGateThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    constexpr int kMaxThreadgroupSize = 1024;
    for (int simdgroups = 1; simdgroups * kThreadsPerSimdgroup <= kMaxThreadgroupSize; ++simdgroups) {
        const bool divides_evenly = (128 % simdgroups) == 0;
        if (divides_evenly) {
            b->Args({simdgroups * kThreadsPerSimdgroup});
        }
    }
}

// Sweep the MLP gate threadgroup size on both model sizes.
BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpGateThreadgroupSizeArguments);
BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpGateThreadgroupSizeArguments);
// Decode benchmark that sweeps model->mlp_swiglu_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void mlp_swiglu_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->mlp_swiglu_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for mlp_swiglu_tgsize. Registers, under the name "tgsize",
// every threadgroup size that is an even number of 32-thread simdgroups (multiples
// of 64 threads, from 64 up to 1024) whose simdgroup count divides 5760 evenly.
// NOTE(review): the meaning of 5760 is not stated here — confirm against the
// SwiGLU kernel's partitioning requirements.
static void MlpSwigluThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    constexpr int kMaxThreadgroupSize = 1024;
    for (int simdgroups = 2; simdgroups * kThreadsPerSimdgroup <= kMaxThreadgroupSize; simdgroups += 2) {
        const bool divides_evenly = (5760 % simdgroups) == 0;
        if (divides_evenly) {
            b->Args({simdgroups * kThreadsPerSimdgroup});
        }
    }
}

// Sweep the MLP SwiGLU threadgroup size on both model sizes.
BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpSwigluThreadgroupSizeArguments);
BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpSwigluThreadgroupSizeArguments);
// Decode benchmark that sweeps model->mlp_out_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void mlp_out_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->mlp_out_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for mlp_out_tgsize. Registers, under the name "tgsize", every
// threadgroup size that is an even number of 32-thread simdgroups (multiples of 64
// threads, from 64 up to 1024) whose simdgroup count divides 5760 evenly.
// NOTE(review): the meaning of 5760 is not stated here — confirm against the MLP
// output kernel's partitioning requirements.
static void MlpOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    constexpr int kMaxThreadgroupSize = 1024;
    for (int simdgroups = 2; simdgroups * kThreadsPerSimdgroup <= kMaxThreadgroupSize; simdgroups += 2) {
        const bool divides_evenly = (5760 % simdgroups) == 0;
        if (divides_evenly) {
            b->Args({simdgroups * kThreadsPerSimdgroup});
        }
    }
}

// Sweep the MLP output threadgroup size on both model sizes.
BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpOutThreadgroupSizeArguments);
BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpOutThreadgroupSizeArguments);
// Decode benchmark that sweeps model->mlp_acc_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void mlp_acc_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->mlp_acc_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for mlp_acc_tgsize. Registers, under the name "tgsize", every
// multiple of 32 threads from 32 to 1024 (1 to 32 simdgroups of 32 threads); unlike
// the other sweeps there is no divisibility filter.
static void MlpAccThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    for (int simdgroups = 1; simdgroups <= 32; ++simdgroups) {
        b->Args({simdgroups * kThreadsPerSimdgroup});
    }
}

// Sweep the MLP accumulation threadgroup size on both model sizes.
BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpAccThreadgroupSizeArguments);
BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(MlpAccThreadgroupSizeArguments);
// Decode benchmark that sweeps model->unembedding_threadgroup_size over the values
// registered for it (read from state.range(0)).
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// overrides the threadgroup size, prefills a fixed short prompt once, and then in
// every benchmark iteration rewinds the context to the end of the prompt and samples
// kNumGeneratedTokens tokens at temperature 1.0 with a per-iteration RNG seed.
// Every failure is reported through state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void unembedding_tgsize(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
    // Parameter under test; set before the context is created and prefilled.
    model->unembedding_threadgroup_size = static_cast<std::size_t>(state.range(0));

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// Argument generator for unembedding_tgsize. Registers, under the name "tgsize",
// every multiple of 32 threads from 32 to 1024 (1 to 32 simdgroups of 32 threads);
// no divisibility filter is applied.
static void UnembeddingThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
    b->ArgNames({"tgsize"});
    constexpr int kThreadsPerSimdgroup = 32;
    for (int simdgroups = 1; simdgroups <= 32; ++simdgroups) {
        b->Args({simdgroups * kThreadsPerSimdgroup});
    }
}

// Sweep the unembedding threadgroup size on both model sizes.
BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(UnembeddingThreadgroupSizeArguments);
BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
    ->UseRealTime()
    ->Unit(benchmark::kMillisecond)
    ->Apply(UnembeddingThreadgroupSizeArguments);

BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/benchmark/end-to-end.cc
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Number of tokens sampled in every iteration of the decode benchmark.
constexpr std::uint32_t kNumGeneratedTokens = 100;

// End-to-end decode benchmark.
//
// Loads the model whose path is stored in the environment variable `env_var_name`,
// prefills a fixed short prompt once, and then in every benchmark iteration rewinds
// the context to the end of the prompt and samples kNumGeneratedTokens tokens at
// temperature 1.0 with a per-iteration RNG seed. Every failure is reported through
// state.SkipWithError().
//
// Counters: "generations" (completed iterations per second) and "tokens"
// (sampled tokens per second).
static void end2end_decode(benchmark::State& state, const char* env_var_name) {
    const char* model_path = getenv(env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to load model from file {}", model_path));
        return;
    }
    // Owning handle: gptoss_model_release runs on every exit path below.
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);

    const char* prompt = "why did the chicken cross the road?";
    std::size_t num_prompt_tokens = 0;
    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
        return;
    }

    // Prefill
    status = gptoss_context_process(context.get());
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to prefill Context object");
        return;
    }

    std::uint64_t rng_seed = 0;
    for (auto _ : state) {
        const std::uint64_t current_rng_seed = rng_seed++;
        // Rewind so every iteration starts decoding right after the prompt.
        context->num_kv_tokens = num_prompt_tokens;
        context->num_tokens = num_prompt_tokens;

        // Token ids are assumed to be std::uint32_t (the output buffer type of
        // gptoss_context_sample) — confirm against gpt-oss.h.
        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
        std::size_t num_generated_tokens = 0;
        do {
            std::size_t num_current_generated_tokens = 0;
            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
            if (status != gptoss_status_success) {
                state.SkipWithError("failed to sample from the Context object");
                return;
            }
            if (num_current_generated_tokens == 0) {
                // The loop only advances by the number of sampled tokens, so a call
                // that returns none would otherwise spin forever.
                state.SkipWithError("sampling from the Context object produced no tokens");
                return;
            }
            num_generated_tokens += num_current_generated_tokens;
        } while (num_generated_tokens < kNumGeneratedTokens);
    }

    state.counters["generations"] =
        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
    state.counters["tokens"] =
        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
}
// End-to-end prefill benchmark.
//
// Reads the model path from the environment variable named by
// `model_path_env_var_name` and a prompt file path from the one named by
// `prompt_env_var_name`. It tokenizes the whole file into a fresh context
// (max_batch_tokens = 1024) and optionally truncates the context to its first
// `context_length` tokens (0 keeps all of them). Each iteration then runs
// gptoss_context_process() and resets num_kv_tokens to 0, so every iteration
// prefills from an empty KV cache.
//
// Counters: "tokens" (tokens prefilled per iteration) and "tokens/s".
// Every failure, including a prompt with fewer than `context_length` tokens, is
// reported through state.SkipWithError().
static void end2end_prefill(benchmark::State& state,
                            const char* model_path_env_var_name,
                            const char* prompt_env_var_name,
                            size_t context_length = 0) {
    const char* model_path = getenv(model_path_env_var_name);
    if (model_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set",
                                        model_path_env_var_name));
        return;
    }
    const char* prompt_file_path = getenv(prompt_env_var_name);
    if (prompt_file_path == NULL) {
        state.SkipWithError(std::format("environment variable {} is not set",
                                        prompt_env_var_name));
        return;
    }

    // Read prompt contents from file into a std::string
    std::ifstream prompt_file(prompt_file_path,
                              std::ios::in | std::ios::binary);
    if (!prompt_file) {
        state.SkipWithError(
            std::format("failed to open prompt file {}", prompt_file_path));
        return;
    }
    std::string prompt_str;
    prompt_file.seekg(0, std::ios::end);
    const std::streampos file_size = prompt_file.tellg();
    if (file_size < 0) {
        state.SkipWithError(std::format("failed to read prompt file size {}",
                                        prompt_file_path));
        return;
    }
    prompt_str.resize(static_cast<std::size_t>(file_size));
    prompt_file.seekg(0, std::ios::beg);
    if (file_size > 0) {
        prompt_file.read(prompt_str.data(), static_cast<std::streamsize>(prompt_str.size()));
    }
    if (!prompt_file) {
        state.SkipWithError(
            std::format("failed to read prompt file {}", prompt_file_path));
        return;
    }

    gptoss_model_t model_ptr = nullptr;
    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(
            std::format("failed to load model from file {}", model_path));
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>,
                    decltype(&gptoss_model_release)>
        model(model_ptr, gptoss_model_release);

    gptoss_tokenizer_t tokenizer_ptr = nullptr;
    status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to retrieve Tokenizer");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_tokenizer_t>,
                    decltype(&gptoss_tokenizer_release)>
        tokenizer(tokenizer_ptr, gptoss_tokenizer_release);

    gptoss_context_t context_ptr = nullptr;
    status = gptoss_context_create(model.get(),
                                   /*context_length=*/0,
                                   /*max_batch_tokens=*/1024,
                                   &context_ptr);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to create Context object");
        return;
    }
    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>,
                    decltype(&gptoss_context_release)>
        context(context_ptr, gptoss_context_release);

    const char* prompt = prompt_str.c_str();
    status = gptoss_context_append_chars(context.get(), prompt,
                                         prompt_str.size(), nullptr);
    if (status != gptoss_status_success) {
        state.SkipWithError(std::format(
            "failed to tokenize prompt from file {}", prompt_file_path));
        return;
    }

    size_t num_tokens = 0;
    status = gptoss_context_get_num_tokens(context.get(), &num_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to get number of tokens");
        return;
    }
    if (context_length != 0) {
        // Checked at runtime instead of with assert(): assert() is compiled out under
        // NDEBUG, and raising num_tokens above the number of tokenized prompt tokens
        // would presumably make the prefill read token slots that were never written.
        if (context_length > num_tokens) {
            state.SkipWithError(std::format(
                "prompt file {} has {} tokens, fewer than the requested context length {}",
                prompt_file_path, num_tokens, context_length));
            return;
        }
        context->num_tokens = context_length;
    }
    // Re-query so the counters reflect the (possibly truncated) token count.
    status = gptoss_context_get_num_tokens(context.get(), &num_tokens);
    if (status != gptoss_status_success) {
        state.SkipWithError("failed to get number of tokens");
        return;
    }

    // Prefill
    for (auto _ : state) {
        status = gptoss_context_process(context.get());
        if (status != gptoss_status_success) {
            state.SkipWithError("failed to prefill Context object");
            return;
        }
        // Drop the KV cache so the next iteration prefills from scratch.
        context->num_kv_tokens = 0;
    }

    state.counters["tokens"] = num_tokens;
    state.counters["tokens/s"] = benchmark::Counter(
        state.iterations() * num_tokens, benchmark::Counter::kIsRate);
}
// Decode end-to-end benchmarks: 20B first, then 120B.
BENCHMARK_CAPTURE(end2end_decode, gpt_oss_20b_decode, "GPT_OSS_20B_PATH")->UseRealTime()->Unit(benchmark::kMillisecond);
BENCHMARK_CAPTURE(end2end_decode, gpt_oss_120b_decode, "GPT_OSS_120B_PATH")->UseRealTime()->Unit(benchmark::kMillisecond);

// Prefill end-to-end benchmarks: the tokenized prompt file is truncated to 1024 and
// then to 3072 tokens; for each length 120B runs first, then 20B.
BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_1024, "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 1024)->UseRealTime()->Unit(benchmark::kMillisecond);
BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_1024, "GPT_OSS_20B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 1024)->UseRealTime()->Unit(benchmark::kMillisecond);
BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_3072, "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 3072)->UseRealTime()->Unit(benchmark::kMillisecond);
BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_3072, "GPT_OSS_20B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 3072)->UseRealTime()->Unit(benchmark::kMillisecond);

BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
================================================
#include
#include
#include
#include
#include
#include
using gptoss::Check;
using namespace gptoss::metal;
// Epsilon forwarded to the RMSNorm kernel launch.
constexpr float kEpsilon = 1.0e-5f;
// Seed shared by the random fills that initialise the input and weight buffers.
constexpr uint64_t kSeed = UINT64_C(1019827666124465388);

// Benchmarks the gptoss_f32_bf16w_rmsnorm Metal kernel on `num_tokens` rows of
// `state.range(0)` channels: fp32 input, bf16 weights, fp32 output.
//
// The input and weight buffers are filled once with random values between -1 and 1
// before timing. Each iteration encodes one kernel launch; the duration returned by
// wait_completion() is recorded as the manual iteration time.
//
// Counters: "elements" (normalized elements per second) and "bytes"
// (input + weight + output bytes per second).
//
// NOTE(review): "rnsnorm" looks like a typo for "rmsnorm"; it is kept because
// BENCHMARK() derives the reported benchmark name from the function name.
static void f32_bf16w_rnsnorm(benchmark::State& state) {
    const size_t num_tokens = 1;
    const size_t num_channels = state.range(0);
    const size_t num_elements = num_tokens * num_channels;

    Device device;
    CommandQueue command_queue{device};
    Library library{device};
    Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
    Function bf16_fill_random_fn{library, "gptoss_bf16_fill_random"};
    Function f32_bf16w_rmsnorm_fn{library, "gptoss_f32_bf16w_rmsnorm"};
    Buffer input_buffer{device, num_elements * sizeof(float)};
    Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};
    Buffer output_buffer{device, num_elements * sizeof(float)};
    Buffer control_buffer{device, sizeof(gptoss_control)};
    std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));

    {
        CommandBuffer command_buffer{command_queue};
        // Position in the kSeed random stream; advanced past the input draw so the
        // weight fill presumably gets values independent of the input.
        size_t offset = 0;
        // Fill every input row (not just the first) so the kernel never reads
        // uninitialized input if num_tokens is raised above 1.
        Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
                  command_buffer.handle(),
                  f32_fill_random_fn.handle(),
                  /*threadgroup_size=*/0,
                  /*max_threadgroups=*/10,
                  /*output_buffer=*/input_buffer.handle(),
                  /*output_offset=*/0,
                  num_elements, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0f),
              "gptoss_metal_command_buffer_encode_launch_f32_fill_random");
        offset += num_elements;
        Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
                  command_buffer.handle(),
                  bf16_fill_random_fn.handle(),
                  /*threadgroup_size=*/0,
                  /*max_threadgroups=*/10,
                  /*output_buffer=*/weight_buffer.handle(),
                  /*output_offset=*/0,
                  num_channels, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0f),
              "gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
        command_buffer.commit();
        command_buffer.wait_completion();
    }

    for (auto _ : state) {
        CommandBuffer command_buffer{command_queue};
        Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
                  command_buffer.handle(),
                  f32_bf16w_rmsnorm_fn.handle(),
                  input_buffer.handle(),
                  /*input_offset=*/0,
                  weight_buffer.handle(),
                  /*weight_offset=*/0,
                  output_buffer.handle(),
                  /*output_offset=*/0,
                  control_buffer.handle(),
                  /*control_offset=*/0,
                  num_tokens,
                  num_channels,
                  kEpsilon),
              "gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm");
        command_buffer.commit();
        const double elapsed_seconds = command_buffer.wait_completion();
        state.SetIterationTime(elapsed_seconds);
    }

    state.counters["elements"] =
        benchmark::Counter(state.iterations() * num_elements,
                           benchmark::Counter::kIsRate);
    const int64_t bytes_per_iteration = input_buffer.size() + weight_buffer.size() + output_buffer.size();
    state.counters["bytes"] =
        benchmark::Counter(state.iterations() * bytes_per_iteration,
                           benchmark::Counter::kIsRate);
}

BENCHMARK(f32_bf16w_rnsnorm)->Arg(2880)->UseManualTime()->Unit(benchmark::kMicrosecond);
BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/benchmark/f32-random.cc
================================================
#include
#include
#include
#include
using gptoss::Check;
using namespace gptoss::metal;
// Benchmarks the gptoss_f32_fill_random Metal kernel writing `state.range(0)` floats
// drawn between -1 and 7 from a fixed seed and stream offset. Each iteration encodes
// one launch; the duration returned by wait_completion() is recorded as the manual
// iteration time.
//
// Counters: "elements" (floats written per second) and "bytes" (bytes written per
// second).
static void f32_fill_random(benchmark::State& state) {
    constexpr uint64_t seed = UINT64_C(1019827666124465388);
    constexpr uint64_t offset = UINT64_C(12345678901234567890);
    const float lower = -1.0f;
    const float upper = 7.0f;
    const size_t num_elements = state.range(0);

    Device device;
    CommandQueue command_queue{device};
    Library library{device};
    Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
    Buffer output_buffer{device, num_elements * sizeof(float)};

    for (auto _ : state) {
        CommandBuffer command_buffer{command_queue};
        Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
                  command_buffer.handle(),
                  f32_fill_random_fn.handle(),
                  /*threadgroup_size=*/0,
                  /*max_threadgroups=*/120,
                  /*output_buffer=*/output_buffer.handle(),
                  /*output_offset=*/0,
                  num_elements, seed, offset, lower, upper),
              "gptoss_metal_command_buffer_encode_launch_f32_fill_random");
        command_buffer.commit();
        state.SetIterationTime(command_buffer.wait_completion());
    }

    const int64_t elements_per_iteration = num_elements;
    const int64_t bytes_per_iteration = num_elements * sizeof(float);
    state.counters["elements"] = benchmark::Counter(
        state.iterations() * elements_per_iteration, benchmark::Counter::kIsRate);
    state.counters["bytes"] = benchmark::Counter(
        state.iterations() * bytes_per_iteration, benchmark::Counter::kIsRate);
}

// 2^30; the benchmark below fills 2 * 2^30 floats.
constexpr int64_t giga = INT64_C(1073741824);
BENCHMARK(f32_fill_random)->Arg(2 * giga)->UseManualTime()->Unit(benchmark::kMicrosecond);
BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/benchmark/mf4-f32-convert.cc
================================================
#include
#include
#include
#include
#include
#include
using gptoss::Check;
using namespace gptoss::metal;
// Benchmarks the gptoss_mf4_f32_convert Metal kernel. It converts `state.range(0)`
// blocks of 4-bit values (32 per block, packed two per byte), plus one
// gptoss_float8ue8m0 scale per block, into fp32 output.
//
// Inputs are constant: every packed byte is 0x91 (to force subnormals) and every
// scale byte is 128 (scale = 2.0). The duration returned by wait_completion() is
// recorded as the manual iteration time.
//
// Counters: "blocks", "elements", and "bytes" (packed + scale + fp32 output bytes),
// all per second.
static void mf4_f32_convert(benchmark::State& state) {
    const size_t num_blocks = state.range(0);
    const size_t num_elements = num_blocks * 32;
    const size_t num_bytes = num_elements / 2;
    // Size the scales by their element type everywhere (allocation, fill and byte
    // count) so the three cannot disagree.
    const size_t num_scale_bytes = num_blocks * sizeof(gptoss_float8ue8m0);

    Device device;
    CommandQueue command_queue{device};
    Library library{device};
    Function mf4_f32_convert_fn{library, "gptoss_mf4_f32_convert"};
    Buffer block_buffer{device, num_bytes};
    Buffer scale_buffer{device, num_scale_bytes};
    Buffer output_buffer{device, num_elements * sizeof(float)};
    std::memset(block_buffer.ptr(), 0x91, num_bytes);  // force subnormals
    std::memset(scale_buffer.ptr(), 128, num_scale_bytes);  // scale = 2.0

    for (auto _ : state) {
        CommandBuffer command_buffer{command_queue};
        Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
                  command_buffer.handle(),
                  mf4_f32_convert_fn.handle(),
                  /*threadgroup_size=*/0,
                  /*max_threadgroups=*/120,
                  block_buffer.handle(),
                  scale_buffer.handle(),
                  output_buffer.handle(),
                  num_elements),
              "gptoss_metal_command_buffer_encode_launch_mf4_f32_convert");
        command_buffer.commit();
        const double elapsed_seconds = command_buffer.wait_completion();
        state.SetIterationTime(elapsed_seconds);
    }

    state.counters["blocks"] =
        benchmark::Counter(state.iterations() * num_blocks,
                           benchmark::Counter::kIsRate);
    state.counters["elements"] =
        benchmark::Counter(state.iterations() * num_elements,
                           benchmark::Counter::kIsRate);
    const int64_t bytes_per_iteration = num_bytes + num_scale_bytes + num_elements * sizeof(float);
    state.counters["bytes"] =
        benchmark::Counter(state.iterations() * bytes_per_iteration,
                           benchmark::Counter::kIsRate);
}

constexpr int64_t mega = INT64_C(1048576);
BENCHMARK(mf4_f32_convert)->Arg(256 * mega)->UseManualTime()->Unit(benchmark::kMicrosecond);
BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/benchmark/u32-random.cc
================================================
#include
#include
#include
#include
using gptoss::Check;
using namespace gptoss::metal;
// Benchmarks the gptoss_u32_fill_random Metal kernel. It fills a buffer of state.range(0) 32-bit
// unsigned integers with random values derived from a fixed (seed, offset) pair.
//
// Iteration time is the GPU command-buffer time (UseManualTime below). The element and byte
// counters are reported as rates.
static void u32_fill_random(benchmark::State& state) {
    const size_t numel = state.range(0);
    // The kernel writes uint32_t values, so size the buffer and byte counter by that type.
    const size_t num_bytes = numel * sizeof(uint32_t);

    Device device;
    CommandQueue command_queue{device};
    Library library{device};
    Function u32_fill_random_fn{library, "gptoss_u32_fill_random"};
    Buffer buffer{device, num_bytes};

    constexpr uint64_t seed = UINT64_C(1019827666124465388);
    constexpr uint64_t offset = UINT64_C(12345678901234567890);
    for (auto _ : state) {
        CommandBuffer command_buffer{command_queue};
        Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
                command_buffer.handle(),
                u32_fill_random_fn.handle(),
                /*threadgroup_size=*/0,
                /*max_threadgroups=*/120,
                /*output_buffer=*/buffer.handle(),
                /*output_offset=*/0,
                numel, seed, offset),
            "gptoss_metal_command_buffer_encode_launch_u32_fill_random");
        command_buffer.commit();
        const double elapsed_seconds = command_buffer.wait_completion();
        state.SetIterationTime(elapsed_seconds);
    }

    const int64_t elements_per_iteration = numel;
    state.counters["elements"] =
        benchmark::Counter(state.iterations() * elements_per_iteration,
            benchmark::Counter::kIsRate);
    const int64_t bytes_per_iteration = num_bytes;
    state.counters["bytes"] =
        benchmark::Counter(state.iterations() * bytes_per_iteration,
            benchmark::Counter::kIsRate);
}

constexpr int64_t giga = INT64_C(1073741824);
BENCHMARK(u32_fill_random)->Arg(2 * giga)->UseManualTime()->Unit(benchmark::kMicrosecond);
BENCHMARK_MAIN();
================================================
FILE: gpt_oss/metal/examples/chat.py
================================================
#!/usr/bin/env python
import argparse
import sys
from datetime import date
from gpt_oss.metal import Context, Model
# Default system prompt sent as the first message of every chat.
# The current date is substituted once, when the module is imported.
DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: {date.today().isoformat()}
reasoning effort high
# Valid channels: analysis, final. Channel must be included for every message."""
# Command-line interface. ArgumentDefaultsHelpFormatter appends each option's default to its --help text.
parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
# 0 is forwarded to Context(context_length=0), which lets the library use the model's maximum.
parser.add_argument(
"--context-length", type=int, default=0, help="The maximum context length"
)
parser.add_argument(
"--temperature", type=float, default=1.0, help="Sampling temperature"
)
parser.add_argument(
"--seed", type=int, default=0, help="Sampling seed"
)
# ANSI terminal escape sequences: grey foreground, bold, and reset-all-attributes.
GREY = "\33[90m"
BOLD = "\33[1m"
RESET = "\33[0m"
def sample_next_token(context, temperature=1.0, seed=0):
    """Sample exactly one next-token ID from ``context``.

    The native ``Context.sample`` requires a ``max_output_tokens`` argument and
    returns a list of token IDs. This helper requests a single token and unwraps it.

    Args:
        context: A ``Context`` (or any object with a compatible ``sample`` method).
        temperature: Sampling temperature forwarded to ``Context.sample``.
        seed: Sampling seed forwarded to ``Context.sample``.

    Returns:
        The sampled token ID.

    Raises:
        RuntimeError: If ``Context.sample`` returned no tokens.
    """
    tokens = context.sample(max_output_tokens=1, temperature=temperature, seed=seed)
    if not tokens:
        raise RuntimeError("Context.sample() returned no tokens")
    return tokens[0]


def main(args):
    """Run an interactive chat session with a gpt-oss model.

    Each message is framed as ``<|start|>{role}<|message|>{content}<|end|>``, and the
    assistant turn is primed with ``<|start|>assistant<|channel|>``. Text on the
    ``analysis`` channel is printed in grey; text on other channels is printed in the
    default color. The session ends on EOF (Ctrl-D) or Ctrl-C at the user prompt.

    Args:
        args: Command-line arguments without the program name, parsed by ``parser``.
    """
    options = parser.parse_args(args)
    model = Model(options.model)
    tokenizer = model.tokenizer
    start_token = tokenizer.encode_special_token("<|start|>")
    message_token = tokenizer.encode_special_token("<|message|>")
    end_token = tokenizer.encode_special_token("<|end|>")
    return_token = tokenizer.encode_special_token("<|return|>")
    channel_token = tokenizer.encode_special_token("<|channel|>")

    context = Context(model, context_length=options.context_length)
    context.append(start_token)
    context.append("system")
    context.append(message_token)
    context.append(options.prompt)
    context.append(end_token)

    while True:
        # Read the user's turn before touching the context, so that EOF / Ctrl-C does
        # not leave a half-written user message header behind.
        try:
            message = input(f"{BOLD}User:{RESET} ").rstrip()
        except (EOFError, KeyboardInterrupt):
            print(flush=True)
            return
        context.append(start_token)
        context.append("user")
        context.append(message_token)
        context.append(message)
        context.append(end_token)

        print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
        # Prime the assistant turn; the model continues from inside the channel header.
        context.append(start_token)
        context.append("assistant")
        context.append(channel_token)
        inside_start_block = True
        inside_channel_block = True
        role = "assistant"
        channel = ""
        while True:
            # NOTE(review): assumes Context.sample() does not append the sampled token
            # itself, as the explicit append below (kept from the original) requires.
            token = sample_next_token(context, temperature=options.temperature, seed=options.seed)
            context.append(token)
            if token == return_token:
                print(flush=True)
                break
            elif token == start_token:
                inside_start_block = True
                role = ""
                channel = ""
            elif token == message_token:
                inside_start_block = False
                inside_channel_block = False
                # Set the color explicitly for every message. Otherwise a non-analysis
                # message would inherit the grey printed after the "Assistant:" prefix.
                print(GREY if channel == "analysis" else RESET, end="", flush=True)
            elif token == end_token:
                print(f"{RESET}", flush=True)
            elif token == channel_token:
                inside_channel_block = True
            elif token < tokenizer.num_text_tokens:
                if inside_channel_block:
                    channel += str(tokenizer.decode(token), encoding="utf-8", errors="replace")
                elif inside_start_block:
                    role += str(tokenizer.decode(token), encoding="utf-8", errors="replace")
                else:
                    # Write raw bytes: one token may carry an incomplete UTF-8 sequence.
                    sys.stdout.buffer.write(tokenizer.decode(token))
                    sys.stdout.buffer.flush()
if __name__ == "__main__":
    # Forward the command-line arguments, minus the program name, to main().
    cli_args = sys.argv[1:]
    main(cli_args)
================================================
FILE: gpt_oss/metal/examples/generate.py
================================================
#!/usr/bin/env python
import argparse
import sys
from gpt_oss.metal import Context, Model
# Command-line interface. ArgumentDefaultsHelpFormatter appends each option's default to its --help text.
# The description previously read "Chat with gpt-oss" (copied from chat.py); this script generates text.
parser = argparse.ArgumentParser(description='Generate text with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
# 0 is forwarded to Context(context_length=0), which lets the library use the model's maximum.
parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
def sample_next_token(context, temperature=1.0, seed=0):
    """Sample exactly one next-token ID from ``context``.

    The native ``Context.sample`` requires a ``max_output_tokens`` argument and
    returns a list of token IDs. This helper requests a single token and unwraps it.
    The defaults match the binding's defaults (temperature 1.0, seed 0).

    Raises:
        RuntimeError: If ``Context.sample`` returned no tokens.
    """
    tokens = context.sample(max_output_tokens=1, temperature=temperature, seed=seed)
    if not tokens:
        raise RuntimeError("Context.sample() returned no tokens")
    return tokens[0]


def main(args):
    """Generate a continuation of ``--prompt`` and stream it to stdout.

    The prompt is tokenized into a fresh Context and its token IDs are printed. Then up
    to ``--limit`` tokens are sampled one at a time, appended to the Context, and written
    to stdout as raw bytes.

    Args:
        args: Command-line arguments without the program name, parsed by ``parser``.
    """
    options = parser.parse_args(args)
    model = Model(options.model)
    context = Context(model, context_length=options.context_length)
    context.append(options.prompt)
    # Flush now: the generated text below bypasses the text layer (sys.stdout.buffer),
    # so unflushed text could otherwise appear after it when stdout is not a TTY.
    print(context.tokens, flush=True)
    prompt_tokens = context.num_tokens
    tokenizer = model.tokenizer
    while context.num_tokens - prompt_tokens < options.limit:
        token = sample_next_token(context)
        context.append(token)
        # Write raw bytes. Decoding each token on its own raises UnicodeDecodeError when
        # a token holds only part of a multi-byte UTF-8 character.
        sys.stdout.buffer.write(tokenizer.decode(token))
        sys.stdout.buffer.flush()
if __name__ == '__main__':
    # Forward the command-line arguments, minus the program name, to main().
    cli_args = sys.argv[1:]
    main(cli_args)
================================================
FILE: gpt_oss/metal/include/gpt-oss/functions.h
================================================
#pragma once
#include
#include
#include
#include
#ifdef __cplusplus
extern "C" {
#endif
/*
* Creates a Model object from a file in the filesystem.
*
* @param path Path to the file containing the model in GPT-OSS format.
* @param model_out Pointer to the Model object that will be created. Must be released with gptoss_model_release.
*
* On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.
* On failure, returns an error code and stores null pointer in the model_out argument.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
const char* path,
gptoss_model_t* model_out);
/*
* Query the Tokenizer object associated with the Model.
*
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
* @param tokenizer_out Pointer to the variable where the Tokenizer reference will be stored.
*
* On success, returns gptoss_status_success and stores reference to the Tokenizer object in the tokenizer_out argument.
* On failure, returns an error code and stores NULL in the tokenizer_out argument.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
gptoss_model_t model,
gptoss_tokenizer_t* tokenizer_out);
/*
* Query the maximum context length supported by the Model.
*
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
* @param max_context_length_out Pointer to the variable where the maximum context length will be stored.
*
* On success, returns gptoss_status_success and stores maximum context length in the max_context_length_out argument.
* On failure, returns an error code and leaves the value specified by max_context_length_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
gptoss_model_t model,
size_t* max_context_length_out);
/*
* Increments a Model object's reference count.
*
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
gptoss_model_t model);
/*
* Decrements a Model object's reference count and possibly releases associated resources.
*
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_release(
gptoss_model_t model);
/*
* Query the token ID for a special token in the Tokenizer vocabulary.
*
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
* @param token_type Type of the special token to query an ID for.
* @param token_id_out Pointer to the variable where the token ID will be stored.
*
* On success, returns gptoss_status_success and stores the token ID in the token_id_out argument.
* On failure, returns an error code and leaves the value specified by token_id_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
gptoss_tokenizer_t tokenizer,
enum gptoss_special_token token_type,
uint32_t* token_id_out);
/*
* Query the number of text tokens in the Tokenizer vocabulary.
*
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
* @param num_text_tokens_out Pointer to the variable where the number of text tokens will be stored.
*
* On success, returns gptoss_status_success and stores the number of text tokens in the num_text_tokens_out argument.
* On failure, returns an error code and leaves the value specified by num_text_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
gptoss_tokenizer_t tokenizer,
uint32_t* num_text_tokens_out);
/*
* Query the number of special tokens in the Tokenizer vocabulary.
*
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
* @param num_special_tokens_out Pointer to the variable where the number of special tokens will be stored.
*
* On success, returns gptoss_status_success and stores the number of special tokens in the num_special_tokens_out argument.
* On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
gptoss_tokenizer_t tokenizer,
uint32_t* num_special_tokens_out);
/*
* Query the total number of tokens in the Tokenizer vocabulary.
*
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
* @param num_tokens_out Pointer to the variable where the total number of tokens will be stored.
*
* On success, returns gptoss_status_success and stores the total number of tokens in the num_tokens_out argument.
* On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
gptoss_tokenizer_t tokenizer,
uint32_t* num_tokens_out);
/*
* Converts a text token ID to its byte representation.
*
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer. The lifetime of the returned
* byte representation matches the lifetime of this Tokenizer object.
* @param token_id ID of the text token to convert.
* @param token_ptr_out Pointer to the variable where the pointer to the byte representation of the token will be
* stored.
* @param token_size_out Pointer to the variable where the size of the byte representation of the token will be stored.
*
* On success, returns gptoss_status_success and stores pointer and size of the byte representation of the token in the
* token_ptr_out and token_size_out arguments.
* On failure, returns an error code and leaves the values specified in token_ptr_out and token_size_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
gptoss_tokenizer_t tokenizer,
uint32_t token_id,
const void** token_ptr_out,
size_t* token_size_out);
/*
* Increments a Tokenizer object's reference count.
*
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
gptoss_tokenizer_t tokenizer);
/*
* Decrements a Tokenizer object's reference count and possibly releases associated resources.
*
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
gptoss_tokenizer_t tokenizer);
/*
* Creates a Context object for use with the particular Model object.
*
* @param model Model object to create a context for.
* @param context_length Maximum number of tokens in the context.
* Specify 0 to use the maximum context length supported by the model.
* @param max_batch_tokens Maximum number of tokens that can be processed in a single batch.
* Larger values may improve prefill performance, but require more memory.
* Specify 0 to use the default value.
* @param context_out Pointer to the Context object that will be created.
* Must be released with gptoss_context_release.
*
* On success, returns gptoss_status_success and saves a pointer to the created Context in the context_out argument.
* On failure, returns an error code and stores null pointer in the context_out argument.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_create(
gptoss_model_t model,
size_t context_length,
size_t max_batch_tokens,
gptoss_context_t* context_out);
/*
* Query the current number of tokens cached in the Context.
*
* @param context Pointer to the Context object created by gptoss_context_create.
* @param num_tokens_out Pointer to the variable where the current number of cached tokens will be stored.
*
* On success, returns gptoss_status_success and stores current number of cached tokens in the num_tokens_out argument.
* On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
gptoss_context_t context,
size_t* num_tokens_out);
/*
* Query the maximum number of tokens cached in the Context.
*
* @param context Pointer to the Context object created by gptoss_context_create.
* @param max_tokens_out Pointer to the variable where the maximum number of cached tokens will be stored.
*
* On success, returns gptoss_status_success and stores maximum number of cached tokens in the max_tokens_out argument.
* On failure, returns an error code and leaves the value specified by max_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
gptoss_context_t context,
size_t* max_tokens_out);
/*
* Query the list of token IDs cached in the Context.
*
* @param context Pointer to the Context object created by gptoss_context_create.
* @param tokens_out Pointer to the array where up to max_tokens of cached tokens will be stored.
* @param max_tokens Maximum capacity of the buffer specified by tokens_out.
* @param num_tokens_out Pointer to the variable where the actual number of cached tokens will be stored.
* This value can exceed max_tokens if the buffer capacity is insufficient.
*
* On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of
* cached tokens in the num_tokens_out argument.
* On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
gptoss_context_t context,
uint32_t* tokens_out,
size_t max_tokens,
size_t* num_tokens_out);
/*
* Tokenizes a character string and appends the resulting tokens to the Context object.
*
* @param context Context object created by gptoss_context_create.
* @param text Pointer to the character string to tokenize and append.
* @param text_length Length of the string, in chars.
* @param num_tokens_out Optional pointer to the variable where the number of appended tokens will be stored. Ignored if a null pointer is provided.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
gptoss_context_t context,
const char* text,
size_t text_length,
size_t* num_tokens_out);
/*
* Appends a list of tokens to the context.
*
* @param context Context object created by gptoss_context_create.
* @param num_tokens Number of tokens to be appended.
* @param tokens Pointer to the array of tokens to be appended.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
gptoss_context_t context,
size_t num_tokens,
const uint32_t* tokens);
/*
* Resets the context, clearing its state.
*
* @param context Context object created by gptoss_context_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
gptoss_context_t context);
/*
* Pre-process the tokens in the Context and generate probability distribution over the next token.
*
* @param context Context object created by gptoss_context_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_process(
gptoss_context_t context);
/*
* Samples up to max_tokens next tokens conditioned on the Context.
*
* @param context Context object created by gptoss_context_create.
* @param temperature Sampling temperature. Must be non-negative.
* @param seed Random number generator seed to use for sampling.
* @param max_tokens Maximum number of tokens to sample; the capacity of the tokens_out array.
* @param tokens_out Pointer to the array where the sampled token IDs will be stored.
* @param num_tokens_out Pointer to the variable where the number of sampled tokens will be stored.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
gptoss_context_t context,
float temperature,
uint64_t seed,
size_t max_tokens,
uint32_t* tokens_out,
size_t* num_tokens_out);
/*
* Increments a Context object's reference count.
*
* @param context Pointer to the Context object created by gptoss_context_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
gptoss_context_t context);
/*
* Decrements a Context object's reference count and possibly releases associated resources.
*
* @param context Pointer to the Context object created by gptoss_context_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_release(
gptoss_context_t context);
/*
* Creates a Sampler object.
*
* @param sampler_out Pointer to the Sampler object that will be created.
* Must be released with gptoss_sampler_release.
*
* On success, returns gptoss_status_success and saves a pointer to the created Sampler in the sampler_out argument.
* On failure, returns an error code and stores a null pointer in the sampler_out argument.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_create(
gptoss_sampler_t* sampler_out);
/*
* Sets the sampling temperature for the Sampler.
*
* @param sampler Sampler object created by gptoss_sampler_create.
* @param temperature Temperature value to be set. Must be in the [0.0, 1.0] range.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_temperature(
gptoss_sampler_t sampler,
float temperature);
/*
* Sets the Top-P nucleus sampling parameter for the Sampler.
*
* @param sampler Sampler object created by gptoss_sampler_create.
* @param top_p Top-P value to be set. Must be in the (0.0, 1.0] range.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_top_p(
gptoss_sampler_t sampler,
float top_p);
/*
* Sets the presence penalty for the Sampler.
*
* @param sampler Sampler object created by gptoss_sampler_create.
* @param presence_penalty Presence penalty value to be set. Must be in the [-2.0, 2.0] range.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_presence_penalty(
gptoss_sampler_t sampler,
float presence_penalty);
/*
* Sets the frequency penalty for the Sampler.
*
* @param sampler Sampler object created by gptoss_sampler_create.
* @param frequency_penalty Frequency penalty value to be set. Must be in the [-2.0, 2.0] range.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_frequency_penalty(
gptoss_sampler_t sampler,
float frequency_penalty);
/*
* Increments a Sampler object's reference count.
*
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_retain(
gptoss_sampler_t sampler);
/*
* Decrements a Sampler object's reference count and possibly releases associated resources.
*
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
*
* On success, returns gptoss_status_success, otherwise returns an error code.
*/
enum gptoss_status GPTOSS_ABI gptoss_sampler_release(
gptoss_sampler_t sampler);
#ifdef __cplusplus
} // extern "C"
#endif
================================================
FILE: gpt_oss/metal/include/gpt-oss/macros.h
================================================
#pragma once
#ifndef GPTOSS_ABI
#define GPTOSS_ABI
#endif // GPTOSS_ABI
================================================
FILE: gpt_oss/metal/include/gpt-oss/types.h
================================================
#pragma once
/*
* Status codes returned by GPT-OSS API functions.
*/
/* NOTE(review): values are explicit and cross the C API boundary; presumably they must stay stable, so do not renumber. */
enum gptoss_status {
/* The only non-error value: API functions return it on success and another value otherwise. */
gptoss_status_success = 0,
gptoss_status_invalid_argument = 1,
gptoss_status_unsupported_argument = 2,
gptoss_status_invalid_state = 3,
gptoss_status_io_error = 4,
gptoss_status_insufficient_memory = 5,
gptoss_status_insufficient_resources = 6,
gptoss_status_unsupported_system = 7,
gptoss_status_context_overflow = 8,
};
enum gptoss_special_token {
gptoss_special_token_invalid = 0,
gptoss_special_token_return = 1,
gptoss_special_token_start = 2,
gptoss_special_token_message = 3,
gptoss_special_token_end = 4,
gptoss_special_token_refusal = 5,
gptoss_special_token_constrain = 6,
gptoss_special_token_channel = 7,
gptoss_special_token_call = 8,
gptoss_special_token_untrusted = 9,
gptoss_special_token_end_untrusted = 10,
gptoss_special_token_max,
};
/*
* Model object is an opaque container comprised of:
* - Weights
* - Temporary buffers required to run the model
* - Any other resources required to run the model
*/
typedef struct gptoss_model* gptoss_model_t;
typedef struct gptoss_tokenizer* gptoss_tokenizer_t;
/*
* Context is an opaque container comprised of:
* - Input tokens
* - Distribution over the output tokens
* - KV cache
*
* Multiple contexts can be created and used with the same model.
*/
typedef struct gptoss_context* gptoss_context_t;
/*
* Sampler is an opaque container for sampling parameters:
* - Temperature
* - Top-p (nucleus sampling)
* - Frequency penalty
* - Presence penalty
*
* Multiple samplers can be created and used with the same context.
*/
typedef struct gptoss_sampler* gptoss_sampler_t;
================================================
FILE: gpt_oss/metal/include/gpt-oss.h
================================================
#pragma once
#include
#include
#include
================================================
FILE: gpt_oss/metal/python/context.c
================================================
#include
#include
#include "module.h"
/*
 * tp_init for Context: Context(model, *, context_length=0, max_batch_tokens=0).
 *
 * Creates a native context for `model` with gptoss_context_create. Passing 0 for
 * context_length or max_batch_tokens selects the library default.
 *
 * Returns 0 on success. Returns -1 with an exception set on failure:
 * TypeError/ValueError for bad arguments, RuntimeError if the native call fails.
 */
static int PyGPTOSSContext_init(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
    static char *kwlist[] = {"model", "context_length", "max_batch_tokens", NULL};
    PyObject* model = NULL;
    Py_ssize_t context_length = 0;
    Py_ssize_t max_batch_tokens = 0;
    /*
     * "n" stores a Py_ssize_t. The previous "i" stored only an int into these wider
     * variables, so negative inputs left the high bytes zero and escaped the checks below.
     */
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$nn", kwlist,
            &PyGPTOSSModel_Type, &model,
            &context_length, &max_batch_tokens))
    {
        return -1;
    }
    if (context_length < 0) {
        PyErr_SetString(PyExc_ValueError, "context_length must be a non-negative integer");
        return -1;
    }
    if (max_batch_tokens < 0) {
        PyErr_SetString(PyExc_ValueError, "max_batch_tokens must be a non-negative integer");
        return -1;
    }

    /* Create into a local so a failed (re-)initialization does not clobber the current handle. */
    gptoss_context_t context = NULL;
    const enum gptoss_status status = gptoss_context_create(
        ((const PyGPTOSSModel*) model)->handle,
        (size_t) context_length,
        (size_t) max_batch_tokens,
        &context);
    if (status != gptoss_status_success) {
        if (context != NULL) {
            (void) gptoss_context_release(context);
        }
        PyErr_Format(PyExc_RuntimeError,
            "failed to create gpt-oss context (gptoss status %d)", (int) status);
        return -1;
    }

    /* __init__ may be called again on a live object: drop the previous handle. */
    if (self->handle != NULL) {
        (void) gptoss_context_release(self->handle);
    }
    self->handle = context;
    return 0;
}
/*
 * tp_dealloc for Context: drops this object's reference to the native context, then
 * frees the Python object.
 *
 * The type is subclassable (Py_TPFLAGS_BASETYPE), so memory must be released through
 * Py_TYPE(self)->tp_free rather than PyObject_Del. A Python subclass may use a different
 * allocator, for example a GC-tracked one.
 */
static void PyGPTOSSContext_dealloc(PyGPTOSSContext* self) {
    (void) gptoss_context_release(self->handle);
    self->handle = NULL;
    Py_TYPE(self)->tp_free((PyObject*) self);
}
/*
 * Context.__copy__: returns a new Python object of the same type that retains and
 * shares the same native context handle. Both objects therefore see the same tokens.
 *
 * Allocation goes through tp_alloc so subclasses get their proper allocator, and the
 * new object starts zero-filled (handle == NULL) until the retain succeeds.
 * Raises RuntimeError if the native retain fails.
 */
static PyObject* PyGPTOSSContext_copy(PyGPTOSSContext *self) {
    PyTypeObject* type = Py_TYPE(self);
    PyGPTOSSContext* copy = (PyGPTOSSContext*) type->tp_alloc(type, 0);
    if (copy == NULL) {
        return NULL;
    }
    const enum gptoss_status status = gptoss_context_retain(self->handle);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to retain gpt-oss context (gptoss status %d)", (int) status);
        Py_DECREF(copy);
        return NULL;
    }
    copy->handle = self->handle;
    return (PyObject*) copy;
}
static PyObject* PyGPTOSSContext_append(PyGPTOSSContext* self, PyObject* arg) {
if (PyBytes_Check(arg)) {
char* string_ptr = NULL;
Py_ssize_t string_size = 0;
if (PyBytes_AsStringAndSize(arg, &string_ptr, &string_size) < 0) {
return NULL;
}
const enum gptoss_status status = gptoss_context_append_chars(
self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);
if (status != gptoss_status_success) {
// TODO: set exception
return NULL;
}
Py_RETURN_NONE;
} else if (PyUnicode_Check(arg)) {
Py_ssize_t string_size = 0;
const char* string_ptr = PyUnicode_AsUTF8AndSize(arg, &string_size);
if (string_ptr == NULL) {
return NULL;
}
const enum gptoss_status status = gptoss_context_append_chars(
self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);
if (status != gptoss_status_success) {
// TODO: set exception
return NULL;
}
Py_RETURN_NONE;
} else if (PyLong_Check(arg)) {
const unsigned long token_as_ulong = PyLong_AsUnsignedLong(arg);
if (token_as_ulong == (unsigned long) -1 && PyErr_Occurred()) {
return NULL;
}
const uint32_t token = (uint32_t) token_as_ulong;
const enum gptoss_status status = gptoss_context_append_tokens(
self->handle, /*num_tokens=*/1, &token);
if (status != gptoss_status_success) {
// TODO: set exception
return NULL;
}
Py_RETURN_NONE;
} else {
PyErr_SetString(PyExc_TypeError, "expected a bytes or integer argument");
return NULL;
}
}
/*
 * Context.process(): runs gptoss_context_process on the tokens accumulated so far.
 * Returns None; raises RuntimeError if the native call fails.
 */
static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {
    const enum gptoss_status status = gptoss_context_process(self->handle);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to process gpt-oss context (gptoss status %d)", (int) status);
        return NULL;
    }
    Py_RETURN_NONE;
}
/*
 * Context.sample(max_output_tokens, *, temperature=1.0, seed=0) -> list[int]
 *
 * Samples up to max_output_tokens token IDs with gptoss_context_sample and returns
 * them as a list.
 *
 * Raises ValueError for a negative max_output_tokens, MemoryError if the output buffer
 * cannot be allocated, and RuntimeError if the native call fails.
 */
static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
    static char *kwlist[] = {"max_output_tokens", "temperature", "seed", NULL};
    PyObject* token_list_obj = NULL;
    uint32_t* token_ptr = NULL;
    Py_ssize_t max_output_tokens = 0;
    unsigned long long seed = 0;
    float temperature = 1.0f;
    size_t num_tokens = 0;

    /* "n" is range-checked, unlike "I", which silently wraps negative counts to huge values. */
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "n|$fK", kwlist,
            &max_output_tokens, &temperature, &seed))
    {
        return NULL;
    }
    if (max_output_tokens < 0) {
        PyErr_SetString(PyExc_ValueError, "max_output_tokens must be a non-negative integer");
        return NULL;
    }

    /* PyMem_New also guards the size multiplication against overflow. */
    token_ptr = PyMem_New(uint32_t, (size_t) max_output_tokens);
    if (token_ptr == NULL) {
        PyErr_NoMemory();
        goto error;
    }

    const enum gptoss_status status = gptoss_context_sample(
        self->handle, temperature, (uint64_t) seed,
        (size_t) max_output_tokens, token_ptr, &num_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to sample from gpt-oss context (gptoss status %d)", (int) status);
        goto error;
    }
    /* Never read past the buffer, whatever count the native call reports. */
    if (num_tokens > (size_t) max_output_tokens) {
        num_tokens = (size_t) max_output_tokens;
    }

    token_list_obj = PyList_New((Py_ssize_t) num_tokens);
    if (token_list_obj == NULL) {
        goto error;
    }
    for (size_t t = 0; t < num_tokens; t++) {
        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
        if (token_obj == NULL) {
            goto error;
        }
        /* Steals the reference to token_obj. */
        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
    }
    PyMem_Free(token_ptr);
    return token_list_obj;

error:
    PyMem_Free(token_ptr);
    Py_XDECREF(token_list_obj);
    return NULL;
}
/*
 * Context.reset(): clears the Context via gptoss_context_reset.
 * Returns None; raises RuntimeError if the native call fails.
 */
static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {
    const enum gptoss_status status = gptoss_context_reset(self->handle);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to reset gpt-oss context (gptoss status %d)", (int) status);
        return NULL;
    }
    Py_RETURN_NONE;
}
/* Methods exposed on Context instances. The table is terminated by a NULL sentinel. */
static PyMethodDef PyGPTOSSContext_methods[] = {
    {"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"},
    /* append() accepts bytes, str, and int (see PyGPTOSSContext_append), not just bytes. */
    {"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes, a str, or an integer token ID to the Context"},
    {"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"},
    {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token predictions from the Context"},
    {"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"},
    {NULL, NULL, 0, NULL},
};
/*
 * Getter for Context.num_tokens: number of tokens currently cached in the Context.
 * Raises RuntimeError if the native query fails.
 */
static PyObject* PyGPTOSSContext_get_num_tokens(PyGPTOSSContext* self, void* closure) {
    size_t num_tokens = 0;
    const enum gptoss_status status = gptoss_context_get_num_tokens(self->handle, &num_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query gpt-oss context token count (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromSize_t(num_tokens);
}
/*
 * Getter for Context.max_tokens: maximum number of tokens the Context can cache.
 * Raises RuntimeError if the native query fails.
 */
static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* closure) {
    size_t max_tokens = 0;
    const enum gptoss_status status = gptoss_context_get_max_tokens(self->handle, &max_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query gpt-oss context capacity (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromSize_t(max_tokens);
}
/*
 * Getter for Context.tokens: list of token IDs cached in the Context.
 *
 * Uses the two-call pattern of gptoss_context_get_tokens. First, query the count with a
 * zero-capacity buffer. Then allocate and fetch the IDs.
 * Raises MemoryError if allocation fails and RuntimeError if the fetch fails.
 */
static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
    PyObject* token_list_obj = NULL;
    uint32_t* token_ptr = NULL;
    size_t num_tokens = 0;

    /*
     * The count query's status is deliberately not checked. Only num_tokens is needed, and
     * the header documents that it may exceed the (zero) capacity passed here.
     */
    (void) gptoss_context_get_tokens(self->handle, /*tokens_out=*/NULL, /*max_tokens=*/0, &num_tokens);
    if (num_tokens != 0) {
        const size_t capacity = num_tokens;
        token_ptr = PyMem_New(uint32_t, capacity);
        if (token_ptr == NULL) {
            PyErr_NoMemory();
            goto error;
        }
        const enum gptoss_status status = gptoss_context_get_tokens(
            self->handle, token_ptr, /*max_tokens=*/capacity, &num_tokens);
        if (status != gptoss_status_success) {
            PyErr_Format(PyExc_RuntimeError,
                "failed to read gpt-oss context tokens (gptoss status %d)", (int) status);
            goto error;
        }
        /* Never read past the buffer, whatever count the second call reports. */
        if (num_tokens > capacity) {
            num_tokens = capacity;
        }
    }

    token_list_obj = PyList_New((Py_ssize_t) num_tokens);
    if (token_list_obj == NULL) {
        goto error;
    }
    for (size_t t = 0; t < num_tokens; t++) {
        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
        if (token_obj == NULL) {
            goto error;
        }
        /* Steals the reference to token_obj. */
        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
    }
    PyMem_Free(token_ptr);
    return token_list_obj;

error:
    PyMem_Free(token_ptr);
    Py_XDECREF(token_list_obj);
    return NULL;
}
static PyGetSetDef PyGPTOSSContext_getseters[] = {
(PyGetSetDef) {
.name = "num_tokens",
.get = (getter) PyGPTOSSContext_get_num_tokens,
.doc = "Current number of tokens in the context",
},
(PyGetSetDef) {
.name = "max_tokens",
.get = (getter) PyGPTOSSContext_get_max_tokens,
.doc = "Maximum number of tokens in the context",
},
(PyGetSetDef) {
.name = "tokens",
.get = (getter) PyGPTOSSContext_get_tokens,
.doc = "List of token IDs in the context",
},
{NULL} /* Sentinel */
};
/*
 * Python type object for "gptoss.Context". Each instance wraps a native context handle
 * managed through the gptoss_context_* API. The type may be subclassed from Python.
 */
PyTypeObject PyGPTOSSContext_Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    /* Identity and layout */
    .tp_name = "gptoss.Context",
    .tp_doc = "Context object",
    .tp_basicsize = sizeof(PyGPTOSSContext),
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    /* Lifecycle */
    .tp_new = PyType_GenericNew,
    .tp_init = (initproc) PyGPTOSSContext_init,
    .tp_dealloc = (destructor) PyGPTOSSContext_dealloc,
    /* Attribute tables */
    .tp_methods = PyGPTOSSContext_methods,
    .tp_getset = PyGPTOSSContext_getseters,
};
================================================
FILE: gpt_oss/metal/python/model.c
================================================
#include
#include
#include "module.h"
/*
 * gptoss.Model.__init__(filepath): load a model file into self->handle.
 *
 * Returns 0 on success. Returns -1 with a Python exception set on failure: argument parsing
 * errors come from PyArg_ParseTuple; load failures raise RuntimeError.
 * The handle is only replaced once the new model loaded successfully, and any handle left by a
 * previous __init__ call is released so repeated initialization does not leak.
 */
static int PyGPTOSSModel_init(PyGPTOSSModel* self, PyObject* args, PyObject* kwargs) {
    const char* filepath = NULL;
    if (!PyArg_ParseTuple(args, "s", &filepath)) {
        return -1;
    }

    gptoss_model_t model = NULL;
    const enum gptoss_status status = gptoss_model_create_from_file(filepath, &model);
    if (status != gptoss_status_success) {
        /* Returning -1 without an exception set would surface as SystemError. */
        PyErr_Format(PyExc_RuntimeError,
            "failed to load model from \"%s\" (gptoss status %d)", filepath, (int) status);
        return -1;
    }

    if (self->handle != NULL) {
        (void) gptoss_model_release(self->handle);
    }
    self->handle = model;
    return 0;
}
/*
 * tp_dealloc for gptoss.Model: drop this object's reference to the native model, then free the
 * Python object with the type's own tp_free.
 *
 * The type is Py_TPFLAGS_BASETYPE, so instances may belong to a subclass that was allocated
 * through a different tp_alloc (e.g. with a GC header). Freeing via Py_TYPE(self)->tp_free keeps
 * allocation and deallocation paired; a hard-coded PyObject_Del does not.
 */
static void PyGPTOSSModel_dealloc(PyGPTOSSModel* self) {
    (void) gptoss_model_release(self->handle);
    self->handle = NULL;
    Py_TYPE(self)->tp_free((PyObject*) self);
}
/*
 * gptoss.Model.__copy__(): return a new Python object of the same type that shares (and retains)
 * the same native model handle.
 *
 * Allocation goes through the concrete type's tp_alloc so the copy matches what tp_dealloc
 * (tp_free) expects, including for subclasses. PyObject_New would bypass a subclass allocator.
 */
static PyObject* PyGPTOSSModel_copy(PyGPTOSSModel* self) {
    PyTypeObject* type = Py_TYPE(self);
    PyGPTOSSModel* copy = (PyGPTOSSModel*) type->tp_alloc(type, 0);
    if (copy == NULL) {
        return NULL;
    }
    (void) gptoss_model_retain(self->handle);
    copy->handle = self->handle;
    return (PyObject*) copy;
}
/* Methods of gptoss.Model: {name, implementation, calling convention, doc}. */
static PyMethodDef PyGPTOSSModel_methods[] = {
    {
        "__copy__",
        (PyCFunction) PyGPTOSSModel_copy,
        METH_NOARGS,
        "Create a copy of the Model",
    },
    {NULL, NULL, 0, NULL},  /* Sentinel */
};
/*
 * Getter for gptoss.Model.max_context_length.
 *
 * Returns a Python int on success. On failure returns NULL with RuntimeError set; a getter must
 * not return NULL without an exception.
 */
static PyObject *PyGPTOSSModel_get_max_context_length(PyGPTOSSModel* self, void* closure) {
    size_t max_context_length = 0;
    const enum gptoss_status status = gptoss_model_get_max_context_length(self->handle, &max_context_length);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query model max context length (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromSize_t(max_context_length);
}
/*
 * Getter for gptoss.Model.tokenizer: equivalent to calling gptoss.Tokenizer(self).
 * Returns a new reference, or NULL with an exception set by the failing call.
 */
static PyObject *PyGPTOSSModel_get_tokenizer(PyGPTOSSModel* self, void* closure) {
    PyObject* call_args = PyTuple_Pack(1, (PyObject*) self);
    if (call_args == NULL) {
        return NULL;
    }
    PyObject* const result = PyObject_CallObject((PyObject*) &PyGPTOSSTokenizer_Type, call_args);
    Py_DECREF(call_args);
    return result;
}
static PyGetSetDef PyGPTOSSModel_getseters[] = {
(PyGetSetDef) {
.name = "max_context_length",
.get = (getter) PyGPTOSSModel_get_max_context_length,
.doc = "Maximum context length supported by the model",
},
(PyGetSetDef) {
.name = "tokenizer",
.get = (getter) PyGPTOSSModel_get_tokenizer,
.doc = "Tokenizer object associated with the model",
},
{NULL} // Sentinel
};
/* Type object for gptoss.Model; subclassable, constructed via PyType_GenericNew + tp_init. */
PyTypeObject PyGPTOSSModel_Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "gptoss.Model",
    .tp_doc = "Model object",
    .tp_basicsize = sizeof(PyGPTOSSModel),
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    .tp_new = PyType_GenericNew,
    .tp_init = (initproc) PyGPTOSSModel_init,
    .tp_dealloc = (destructor) PyGPTOSSModel_dealloc,
    .tp_methods = PyGPTOSSModel_methods,
    .tp_getset = PyGPTOSSModel_getseters,
};
================================================
FILE: gpt_oss/metal/python/module.c
================================================
#include
#include "module.h"
/* The _metal module exposes no module-level functions; only the sentinel entry is present. */
static PyMethodDef module_methods[] = {
    {NULL}  /* Sentinel: remaining fields are zero-initialized */
};
static PyModuleDef metal_module = {
PyModuleDef_HEAD_INIT,
"_metal",
"Local GPT-OSS inference",
-1,
module_methods
};
/*
 * Module initializer for _metal: readies the Model, Tokenizer and Context types and adds them
 * to a new module as "Model", "Tokenizer" and "Context".
 *
 * Returns the module, or NULL with an exception set.
 *
 * Reference handling: each type gets one extra reference (Py_INCREF) intended to be handed to the
 * module. PyModule_AddObject steals that reference only on success. Each local is therefore reset
 * to NULL right after a successful add, so the error path releases only references this function
 * still owns. Previously an error after a successful add over-released the stolen reference.
 */
PyMODINIT_FUNC PyInit__metal(void) {
    PyObject* module = NULL;
    PyObject* model_type = NULL;
    PyObject* tokenizer_type = NULL;
    PyObject* context_type = NULL;

    if (PyType_Ready(&PyGPTOSSModel_Type) < 0) {
        goto error;
    }
    model_type = (PyObject*) &PyGPTOSSModel_Type;
    Py_INCREF(model_type);

    if (PyType_Ready(&PyGPTOSSTokenizer_Type) < 0) {
        goto error;
    }
    tokenizer_type = (PyObject*) &PyGPTOSSTokenizer_Type;
    Py_INCREF(tokenizer_type);

    if (PyType_Ready(&PyGPTOSSContext_Type) < 0) {
        goto error;
    }
    context_type = (PyObject*) &PyGPTOSSContext_Type;
    Py_INCREF(context_type);

    module = PyModule_Create(&metal_module);
    if (module == NULL) {
        goto error;
    }

    if (PyModule_AddObject(module, "Model", model_type) < 0) {
        goto error;
    }
    model_type = NULL;  /* reference now owned by the module */

    if (PyModule_AddObject(module, "Tokenizer", tokenizer_type) < 0) {
        goto error;
    }
    tokenizer_type = NULL;  /* reference now owned by the module */

    if (PyModule_AddObject(module, "Context", context_type) < 0) {
        goto error;
    }
    context_type = NULL;  /* reference now owned by the module */

    return module;

error:
    Py_XDECREF(context_type);
    Py_XDECREF(tokenizer_type);
    Py_XDECREF(model_type);
    Py_XDECREF(module);
    return NULL;
}
================================================
FILE: gpt_oss/metal/python/module.h
================================================
#include
#include
// Python object wrapping a native model handle (Python type "gptoss.Model").
// The handle is released in PyGPTOSSModel_dealloc.
typedef struct {
PyObject_HEAD
gptoss_model_t handle;
} PyGPTOSSModel;
// Python object wrapping a native tokenizer handle (Python type "gptoss.Tokenizer").
// The handle is obtained from a model in PyGPTOSSTokenizer_new and released in its dealloc.
typedef struct {
PyObject_HEAD
gptoss_tokenizer_t handle;
} PyGPTOSSTokenizer;
// Python object wrapping a native inference context handle (Python type "gptoss.Context").
typedef struct {
PyObject_HEAD
gptoss_context_t handle;
} PyGPTOSSContext;
// Static type objects defined in model.c, tokenizer.c and context.c respectively;
// readied and registered on the _metal module by PyInit__metal.
extern PyTypeObject PyGPTOSSModel_Type;
extern PyTypeObject PyGPTOSSTokenizer_Type;
extern PyTypeObject PyGPTOSSContext_Type;
================================================
FILE: gpt_oss/metal/python/tokenizer.c
================================================
#include
#include
#include "module.h"
/*
 * tp_new for gptoss.Tokenizer(model): create a tokenizer object bound to a gptoss.Model.
 *
 * Returns a new reference, or NULL with an exception set:
 *   - TypeError: from argument parsing, if `model` is not a gptoss.Model;
 *   - RuntimeError: if the native tokenizer cannot be obtained;
 *   - MemoryError: from tp_alloc.
 *
 * The native tokenizer is fetched before the Python object is allocated. A failure therefore
 * neither leaks a half-built object nor returns NULL without an exception, as the previous
 * version did.
 */
static PyObject* PyGPTOSSTokenizer_new(PyTypeObject* subtype, PyObject* args, PyObject* kwargs) {
    static char *kwlist[] = {"model", NULL};
    PyObject* model = NULL;
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyGPTOSSModel_Type, &model)) {
        return NULL;
    }

    gptoss_tokenizer_t tokenizer = NULL;
    const enum gptoss_status status = gptoss_model_get_tokenizer(
        ((const PyGPTOSSModel*) model)->handle,
        &tokenizer);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to get tokenizer from model (gptoss status %d)", (int) status);
        return NULL;
    }

    PyGPTOSSTokenizer* self = (PyGPTOSSTokenizer*) subtype->tp_alloc(subtype, 0);
    if (self == NULL) {
        /* The object would own this handle (its dealloc releases it); drop it since no object exists. */
        (void) gptoss_tokenizer_release(tokenizer);
        return NULL;
    }
    self->handle = tokenizer;
    return (PyObject*) self;
}
/*
 * tp_dealloc for gptoss.Tokenizer: release the native tokenizer handle and free the object.
 *
 * Instances are allocated with subtype->tp_alloc (see PyGPTOSSTokenizer_new) and the type is
 * subclassable, so the matching deallocator is Py_TYPE(self)->tp_free rather than PyObject_Del.
 */
static void PyGPTOSSTokenizer_dealloc(PyGPTOSSTokenizer* self) {
    (void) gptoss_tokenizer_release(self->handle);
    self->handle = NULL;
    Py_TYPE(self)->tp_free((PyObject*) self);
}
/*
 * gptoss.Tokenizer.__copy__(): return a new object of the same type sharing (and retaining) the
 * native tokenizer handle.
 *
 * Uses the concrete type's tp_alloc, consistent with PyGPTOSSTokenizer_new, so subclass instances
 * are allocated the way tp_dealloc/tp_free expects.
 */
static PyObject* PyGPTOSSTokenizer_copy(PyGPTOSSTokenizer* self) {
    PyTypeObject* type = Py_TYPE(self);
    PyGPTOSSTokenizer* copy = (PyGPTOSSTokenizer*) type->tp_alloc(type, 0);
    if (copy == NULL) {
        return NULL;
    }
    (void) gptoss_tokenizer_retain(self->handle);
    copy->handle = self->handle;
    return (PyObject*) copy;
}
/*
 * gptoss.Tokenizer.encode_special_token(name): map a special-token string such as "<|start|>"
 * to its token ID.
 *
 * Raises TypeError for non-str arguments and ValueError for unknown names or tokens the
 * tokenizer does not define.
 */
static PyObject* PyGPTOSSTokenizer_encode_special_token(PyGPTOSSTokenizer* self, PyObject* arg) {
    static const struct {
        const char* text;
        enum gptoss_special_token kind;
    } known_tokens[] = {
        {"<|return|>", gptoss_special_token_return},
        {"<|start|>", gptoss_special_token_start},
        {"<|message|>", gptoss_special_token_message},
        {"<|end|>", gptoss_special_token_end},
        {"<|refusal|>", gptoss_special_token_refusal},
        {"<|constrain|>", gptoss_special_token_constrain},
        {"<|channel|>", gptoss_special_token_channel},
        {"<|call|>", gptoss_special_token_call},
        {"<|untrusted|>", gptoss_special_token_untrusted},
        {"<|end_untrusted|>", gptoss_special_token_end_untrusted},
    };

    if (!PyUnicode_Check(arg)) {
        PyErr_SetString(PyExc_TypeError, "string argument expected");
        return NULL;
    }
    const char* name = PyUnicode_AsUTF8(arg);
    if (name == NULL) {
        return NULL;
    }

    enum gptoss_special_token kind = gptoss_special_token_invalid;
    for (size_t i = 0; i < sizeof(known_tokens) / sizeof(known_tokens[0]); i++) {
        if (strcmp(name, known_tokens[i].text) == 0) {
            kind = known_tokens[i].kind;
            break;
        }
    }
    if (kind == gptoss_special_token_invalid) {
        PyErr_Format(PyExc_ValueError, "unrecognized special token: %s", name);
        return NULL;
    }

    /* UINT32_MAX doubles as a "not present in this tokenizer" marker. */
    uint32_t id = UINT32_MAX;
    const enum gptoss_status status = gptoss_tokenizer_get_special_token_id(self->handle, kind, &id);
    if (status != gptoss_status_success || id == UINT32_MAX) {
        PyErr_Format(PyExc_ValueError, "tokenizer does not support the %s token", name);
        return NULL;
    }
    return PyLong_FromUnsignedLong((unsigned long) id);
}
/*
 * gptoss.Tokenizer.decode(token): return the raw bytes of a single token ID.
 *
 * `token` is required. The "I" format converts a Python int to unsigned int without overflow
 * checking. On failure returns NULL with RuntimeError set (previously NULL was returned with no
 * exception, which CPython reports as SystemError).
 */
static PyObject* PyGPTOSSTokenizer_decode(PyGPTOSSTokenizer* self, PyObject* args, PyObject* kwargs) {
    static char *kwlist[] = {"token", NULL};
    unsigned int token = 0;
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I", kwlist, &token)) {
        return NULL;
    }
    const void* token_ptr = NULL;
    size_t token_size = 0;
    const enum gptoss_status status = gptoss_tokenizer_decode(self->handle, (uint32_t) token, &token_ptr, &token_size);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to decode token %u (gptoss status %d)", token, (int) status);
        return NULL;
    }
    return PyBytes_FromStringAndSize((const char*) token_ptr, (Py_ssize_t) token_size);
}
/* Methods of gptoss.Tokenizer: {name, implementation, calling convention, doc}. */
static PyMethodDef PyGPTOSSTokenizer_methods[] = {
    {
        "__copy__",
        (PyCFunction) PyGPTOSSTokenizer_copy,
        METH_NOARGS,
        "Create a copy of the Tokenizer",
    },
    {
        "encode_special_token",
        (PyCFunction) PyGPTOSSTokenizer_encode_special_token,
        METH_O,
        "Query ID of a special token",
    },
    {
        "decode",
        (PyCFunction) PyGPTOSSTokenizer_decode,
        METH_VARARGS | METH_KEYWORDS,
        "Convert text token ID to bytes",
    },
    {NULL, NULL, 0, NULL},  /* Sentinel */
};
/*
 * Getter for gptoss.Tokenizer.num_text_tokens.
 * Returns a Python int, or NULL with RuntimeError set if the native query fails.
 */
static PyObject* PyGPTOSSTokenizer_get_num_text_tokens(PyGPTOSSTokenizer* self, void* closure) {
    uint32_t num_text_tokens = 0;
    const enum gptoss_status status = gptoss_tokenizer_get_num_text_tokens(self->handle, &num_text_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query number of text tokens (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromUnsignedLong((unsigned long) num_text_tokens);
}
/*
 * Getter for gptoss.Tokenizer.num_special_tokens.
 * Returns a Python int, or NULL with RuntimeError set if the native query fails.
 */
static PyObject* PyGPTOSSTokenizer_get_num_special_tokens(PyGPTOSSTokenizer* self, void* closure) {
    uint32_t num_special_tokens = 0;
    const enum gptoss_status status = gptoss_tokenizer_get_num_special_tokens(self->handle, &num_special_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query number of special tokens (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromUnsignedLong((unsigned long) num_special_tokens);
}
/*
 * Getter for gptoss.Tokenizer.num_tokens.
 * Returns a Python int, or NULL with RuntimeError set if the native query fails.
 */
static PyObject* PyGPTOSSTokenizer_get_num_tokens(PyGPTOSSTokenizer* self, void* closure) {
    uint32_t num_tokens = 0;
    const enum gptoss_status status = gptoss_tokenizer_get_num_tokens(self->handle, &num_tokens);
    if (status != gptoss_status_success) {
        PyErr_Format(PyExc_RuntimeError,
            "failed to query number of tokens (gptoss status %d)", (int) status);
        return NULL;
    }
    return PyLong_FromUnsignedLong((unsigned long) num_tokens);
}
static PyGetSetDef PyGPTOSSTokenizer_getseters[] = {
(PyGetSetDef) {
.name = "num_tokens",
.get = (getter) PyGPTOSSTokenizer_get_num_tokens,
.doc = "Total number of tokens in the tokenizer dictionary",
},
(PyGetSetDef) {
.name = "num_text_tokens",
.get = (getter) PyGPTOSSTokenizer_get_num_text_tokens,
.doc = "Number of text tokens in the tokenizer dictionary",
},
(PyGetSetDef) {
.name = "num_special_tokens",
.get = (getter) PyGPTOSSTokenizer_get_num_special_tokens,
.doc = "Number of special tokens in the tokenizer dictionary",
},
{NULL} /* Sentinel */
};
/* Type object for gptoss.Tokenizer; subclassable, constructed entirely in tp_new (no tp_init). */
PyTypeObject PyGPTOSSTokenizer_Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "gptoss.Tokenizer",
    .tp_doc = "Tokenizer object",
    .tp_basicsize = sizeof(PyGPTOSSTokenizer),
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    .tp_new = PyGPTOSSTokenizer_new,
    .tp_dealloc = (destructor) PyGPTOSSTokenizer_dealloc,
    .tp_methods = PyGPTOSSTokenizer_methods,
    .tp_getset = PyGPTOSSTokenizer_getseters,
};
================================================
FILE: gpt_oss/metal/scripts/create-local-model.py
================================================
import argparse
import os
import math
import sys
import json
import itertools
import struct
from uuid import UUID
import tiktoken
import torch
from safetensors import safe_open
from tqdm import tqdm
from openai_harmony import load_harmony_encoding, HarmonyEncodingName
# Command-line interface: both the source checkpoint directory and the destination file are required.
parser = argparse.ArgumentParser(
    prog='create-local-model.py',
    description='Convert a checkpoint directory to a local model file',
)
for short_flag, long_flag, placeholder, help_text in (
    ('-s', '--src', 'DIR', 'Path to the input checkpoint directory'),
    ('-d', '--dst', 'FILE', 'Path to the output model file'),
):
    parser.add_argument(short_flag, long_flag, metavar=placeholder, type=str, required=True, help=help_text)
o200k_base = tiktoken.get_encoding("o200k_base")
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
# o200k_base text vocabulary extended with the gpt-oss special tokens (IDs 199998..200013).
# Every special-token ID must appear exactly once. The writer below recovers each special token's
# string from its ID (decode_single_token_bytes), and an ID shared by two strings makes that
# ID -> string lookup ambiguous. This previously happened for 200008, which was listed both as
# "<|message|>" and as an unused placeholder, so <|message|> could lose its UUID in the output file.
o200k_gptoss = tiktoken.Encoding(
    name="o200k_gptoss",
    pat_str=o200k_base._pat_str,
    mergeable_ranks=o200k_base._mergeable_ranks,
    special_tokens={
        "<|reversed199998|>": 199998,  # unused
        "<|endoftext|>": 199999,
        "<|untrusted|>": 200000,
        # NOTE(review): SPECIAL_TOKEN_UUID / INCLUDE_SPECIAL_TOKENS spell this token
        # "<|end_untrusted|>", so as written it is emitted without a UUID — confirm intended.
        "<|endofuntrusted|>": 200001,
        "<|return|>": 200002,
        "<|constrain|>": 200003,
        "<|reversed200004|>": 200004,  # unused
        "<|channel|>": 200005,
        "<|start|>": 200006,
        "<|end|>": 200007,
        "<|message|>": 200008,
        "<|reversed200009|>": 200009,  # unused
        "<|reversed200010|>": 200010,  # unused
        "<|reversed200011|>": 200011,  # unused
        "<|call|>": 200012,
        "<|refusal|>": 200013,
    }
)
# 16-byte file magic: the ASCII tag "GPT-OSS v1.0" followed by a uint32 zero.
# Packed with an explicit little-endian, no-padding layout ('<') so the bytes do not depend on the
# host's native byte order/alignment (identical to the former native packing on little-endian hosts).
FILE_MAGIC = struct.pack('<12sI', b'GPT-OSS v1.0', 0)
# Fixed 16-byte identifiers written for each special token the runtime recognizes by UUID.
SPECIAL_TOKEN_UUID = {
    '<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes,
    '<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes,
    '<|end|>': UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes,
    '<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes,
    '<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes,
    '<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes,
    '<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes,
    '<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes,
    '<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes,
    '<|end_untrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes,
}
# Special tokens whose UUID is written to the tokenizer table; all others get 16 zero bytes.
INCLUDE_SPECIAL_TOKENS = [
    "<|start|>",
    "<|message|>",
    "<|end|>",
    "<|return|>",
    "<|refusal|>",
    "<|constrain|>",
    "<|channel|>",
    "<|call|>",
    "<|untrusted|>",
    "<|end_untrusted|>",
]
# Section identifiers (model, weight layout, tokenizer) written into the output file.
GPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes
APPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes
TIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes
UE8_OFFSET = 14  # bias to MXFP4 block scales
def write_file_header(f):
    """Write the fixed FILE_MAGIC header at the current position of the binary stream `f`."""
    header = FILE_MAGIC
    f.write(header)
def write_tokenizer_header(f,
num_special_tokens: int,
num_text_tokens: int,
regex_size: int,
tokens_size: int):
f.write(TIKTOKEN_TOKENIZER_UUID)
f.write(struct.pack(' 0
tokens_size += len(token_bytes) + 2 # uint16_t string length + string data
num_text_tokens += 1
# Then add all special tokens
num_included_tokens = 200013 + 1
print(f"Tokenizer: {num_included_tokens} tokens")
# Read from all files ending with .safetensors in the checkpoint directory
safetensor_files = [
os.path.join(options.src, fname)
for fname in os.listdir(options.src)
if fname.endswith(".safetensors")
]
# Build a mapping from tensor name to filepath
tensor_name_to_file = {}
for safetensor_file in safetensor_files:
with safe_open(safetensor_file, framework="pt", device="cpu") as src:
for key in src.keys():
tensor_name_to_file[key] = safetensor_file
def get_tensor(name):
with safe_open(tensor_name_to_file[name], framework="pt", device="cpu") as src:
return src.get_tensor(name)
with open(options.dst, "wb") as dst:
write_file_header(dst)
yarn_low = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
/ math.log(rope_theta)
)
yarn_high = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
/ math.log(rope_theta)
)
write_model_header(dst,
context_length=int(initial_context_length * rope_scaling_factor),
num_blocks=num_blocks,
num_experts=num_experts,
num_active_experts=num_active_experts,
embedding_dim=embedding_dim,
mlp_dim=mlp_dim,
swiglu_limit=swiglu_limit,
head_dim=head_dim,
num_heads=num_q_heads,
num_kv_heads=num_kv_heads,
attention_window=attention_window,
rope_theta=rope_theta,
interpolation_scale=1.0 / rope_scaling_factor,
yarn_offset=-yarn_low / (yarn_high - yarn_low),
yarn_scale=1.0 / (yarn_high - yarn_low),
yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
rmsnorm_epsilon=1.0e-5)
write_tokenizer_header(dst,
num_special_tokens=num_included_tokens - num_text_tokens,
num_text_tokens=num_text_tokens,
regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
tokens_size=tokens_size)
### Tokenizer
# Special tokens
for token_idx in range(num_text_tokens, num_included_tokens):
token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
if token in INCLUDE_SPECIAL_TOKENS:
dst.write(SPECIAL_TOKEN_UUID[token])
else:
dst.write(bytes(16))
# Regex
dst.write(o200k_gptoss._pat_str.encode("ascii"))
dst.write(struct.pack('B', 0))
# Text tokens
tokenizer_bytes_written = 0
for t in range(num_text_tokens):
token_bytes = o200k_gptoss.decode_single_token_bytes(t)
assert len(token_bytes) > 0
dst.write(struct.pack('
#include
#include
#pragma METAL fp math_mode(safe)
#pragma METAL fp contract(off)
// Adds the score-weighted outputs of 4 active experts into `output`, in place:
//   output[o] = fma(in3, s3, fma(in2, s2, fma(in1, s1, fma(in0, s0, output[o]))))
// where, for row gid.y, s_k = expert[gid.y * 4 + k].score and in_k is read at
//   input + gid.y * num_vecs + k * num_vecs_per_expert + (vector index).
// Each threadgroup (gid.x) covers up to num_vecs_per_threadgroup float4 vectors of that row, and
// threads step through them with stride threadgroup_size.x. The kernel exits early if
// control->abort is non-zero.
kernel void gptoss_f32_accumulate_e4(
    constant gptoss_accumulate_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device gptoss_expert_prediction* expert [[ buffer(2) ]],
    device float4* output [[ buffer(3) ]],
    const device gptoss_control* control [[ buffer(4) ]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint2 threadgroup_size [[ threads_per_threadgroup ]])
{
    const uint num_active_experts = 4;
    if (control->abort != 0) {
        return;
    }

    const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
    const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
    const uint num_vecs = args.num_vecs;
    const uint threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, num_vecs);
    const uint thread_start = threadgroup_start + tid;
    // Ceil-divide this thread's remaining span by the stride. All operands are already `uint`, so no
    // cast is needed. (The shown `static_cast(` lacked its template argument and did not compile.)
    // If thread_start exceeds threadgroup_end by less than the stride, the unsigned wrap-around
    // makes this 0, assuming threadgroup_start <= threadgroup_end.
    uint num_iter = (threadgroup_end - thread_start + (threadgroup_size.x - 1)) / threadgroup_size.x;

    const uint num_vecs_per_expert = args.num_vecs_per_expert;
    const float scale0 = expert[gid.y * num_active_experts + 0].score;
    const device float4* input0 = input + gid.y * num_vecs + thread_start;
    const float scale1 = expert[gid.y * num_active_experts + 1].score;
    const device float4* input1 = input0 + num_vecs_per_expert;
    const float scale2 = expert[gid.y * num_active_experts + 2].score;
    const device float4* input2 = input1 + num_vecs_per_expert;
    const float scale3 = expert[gid.y * num_active_experts + 3].score;
    const device float4* input3 = input2 + num_vecs_per_expert;

    output += gid.y * num_vecs + thread_start;
    for (; num_iter != 0; num_iter--) {
        float4 acc = *output;
        const float4 val0 = *input0;
        const float4 val1 = *input1;
        const float4 val2 = *input2;
        const float4 val3 = *input3;
        input0 += threadgroup_size.x;
        // Fixed accumulation order (expert 0..3); fp contraction is disabled by the file's pragma,
        // so the explicit fma calls define the rounding behavior.
        acc = metal::fma(val0, scale0, acc);
        input1 += threadgroup_size.x;
        acc = metal::fma(val1, scale1, acc);
        input2 += threadgroup_size.x;
        acc = metal::fma(val2, scale2, acc);
        input3 += threadgroup_size.x;
        acc = metal::fma(val3, scale3, acc);
        *output = acc;
        output += threadgroup_size.x;
    }
}
================================================
FILE: gpt_oss/metal/source/context.c
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include "internal/datatype.h"
#include "internal/model.h"
#include "internal/metal.h"
#include "internal/metal-kernels.h"
#include "internal/log.h"
#include "internal/rng.h"
/*
 * Create an inference context for `model`, allocating all activation, I/O and KV-cache buffers.
 *
 * context_length:   number of tokens the context can hold; 0 selects model->context_length.
 *                   Must not exceed model->context_length (gptoss_status_invalid_argument).
 * max_batch_tokens: number of tokens processed per batch; 0 selects GPTOSS_DEFAULT_BATCH_SIZE,
 *                   clamped to context_length. An explicit value larger than context_length is
 *                   rejected with gptoss_status_invalid_argument.
 * context_out:      receives the new context (NULL on failure).
 *
 * On success the context retains `model`. On failure everything allocated so far is released via
 * gptoss_context_release and the failing status is returned.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_create(
    gptoss_model_t model,
    size_t context_length,
    size_t max_batch_tokens,
    gptoss_context_t* context_out)
{
    *context_out = NULL;
    enum gptoss_status status = gptoss_status_success;
    struct gptoss_context* context = NULL;

    // Validate context_length
    if (context_length == 0) {
        context_length = model->context_length;
    } else if (context_length > model->context_length) {
        GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
            context_length, model->context_length);
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }
    assert(context_length != 0);
    assert(context_length <= model->context_length);

    // Validate max_batch_tokens
    if (max_batch_tokens == 0) {
        // Clamp the default so a short context still satisfies max_batch_tokens <= context_length.
        max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
        if (max_batch_tokens > context_length) {
            max_batch_tokens = context_length;
        }
    } else if (max_batch_tokens > context_length) {
        GPTOSS_LOG_ERROR("requested max batch tokens %zu exceeds context length %zu",
            max_batch_tokens, context_length);
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }
    assert(max_batch_tokens != 0);
    assert(max_batch_tokens <= context_length);

    context = malloc(sizeof(struct gptoss_context));
    if (context == NULL) {
        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
            sizeof(struct gptoss_context));
        status = gptoss_status_insufficient_memory;
        goto cleanup;
    }
    memset(context, 0, sizeof(struct gptoss_context));
    atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
    context->max_tokens = context_length;
    context->max_batch_tokens = max_batch_tokens;

    // Activation buffers
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // The last entry will hold the total number of tokens.
    status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    // Input/output buffers
    status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    context->kvcache_size = context->kvcache_buffer.size;
    // Total bytes of every buffer allocated above. control_buffer, prob_buffer and sum_buffer were
    // previously omitted from this sum, so the reported size was too small.
    context->allocation_size =
        context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
        context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
        context->gate_activation_buffer.size + context->expert_activation_buffer.size +
        context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
        context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
        context->control_buffer.size + context->token_buffer.size + context->kvcache_buffer.size +
        context->score_buffer.size + context->prob_buffer.size + context->sum_buffer.size +
        context->argmax_buffer.size;

    context->model = model;
    gptoss_model_retain(model);
    *context_out = context;
    context = NULL;  // ownership transferred to the caller; skip release below

cleanup:
    gptoss_context_release(context);
    return status;
}
/* Report how many tokens the context currently holds. Always succeeds. */
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
    gptoss_context_t context,
    size_t* num_tokens_out)
{
    const size_t current = context->num_tokens;
    *num_tokens_out = current;
    return gptoss_status_success;
}
/* Report the token capacity the context was created with. Always succeeds. */
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
    gptoss_context_t context,
    size_t* max_tokens_out)
{
    const size_t capacity = context->max_tokens;
    *max_tokens_out = capacity;
    return gptoss_status_success;
}
/*
 * Copy the context's token IDs into tokens_out.
 *
 * *num_tokens_out always receives the current token count, so callers can first query the
 * required size with tokens_out = NULL and max_tokens = 0.
 * Returns:
 *   - gptoss_status_insufficient_memory if max_tokens is smaller than the token count;
 *   - gptoss_status_invalid_argument if there are tokens to copy but tokens_out is NULL
 *     (previously this reached memcpy with a NULL destination);
 *   - gptoss_status_success otherwise.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
    gptoss_context_t context,
    uint32_t* tokens_out,
    size_t max_tokens,
    size_t* num_tokens_out)
{
    const size_t num_tokens = context->num_tokens;
    *num_tokens_out = num_tokens;
    if (max_tokens < num_tokens) {
        return gptoss_status_insufficient_memory;
    }
    if (num_tokens != 0) {
        if (tokens_out == NULL) {
            return gptoss_status_invalid_argument;
        }
        memcpy(tokens_out, context->token_buffer.ptr, num_tokens * sizeof(uint32_t));
    }
    return gptoss_status_success;
}
// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
static enum gptoss_status process_tokens(
gptoss_context_t context,
struct gptoss_metal_command_buffer* command_buffer,
size_t input_tokens_offset,
size_t num_input_tokens,
size_t num_output_tokens)
{
assert(num_input_tokens != 0);
assert(num_input_tokens <= context->max_batch_tokens);
assert(num_output_tokens <= context->max_batch_tokens);
assert(num_input_tokens >= num_output_tokens);
const size_t min_tokens_for_dense_matmul_kernels = 64;
const size_t min_tokens_for_dense_moe_kernels = 64;
enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
for (size_t input_batch_start = input_tokens_offset;
input_batch_start < input_tokens_end;
input_batch_start += context->max_batch_tokens)
{
const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);
const size_t input_batch_end = input_batch_start + input_batch_size;
const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
command_buffer,
&model->bf16_f32_embeddings_fn,
model->embeddings_threadgroup_size,
&context->token_buffer,
input_batch_start * sizeof(uint32_t),
&model->shared_weight_buffer,
/*weight_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_channels=*/model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
return status;
}
for (uint32_t n = 0; n < model->num_blocks; n++) {
const bool last_block = n + 1 == model->num_blocks;
const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_channels=*/model->embedding_dim,
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
return status;
}
if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
command_buffer,
&model->f32_bf16w_dense_matmul_qkv_fn,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
&context->qkv_activation_buffer,
/*output_offset=*/0,
&context->kvcache_buffer,
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/attn_qkv_dim,
/*max_tokens=*/context->max_tokens,
/*token_offset=*/input_batch_start);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_rope(
command_buffer,
&model->f32_rope_fn,
/*threadgroup_size=*/32,
&context->qkv_activation_buffer,
/*input_offset=*/0,
&context->kvcache_buffer,
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
model->rope_theta,
model->interpolation_scale,
model->yarn_offset,
model->yarn_scale,
model->yarn_multiplier,
input_batch_size,
model->num_heads,
model->num_kv_heads,
model->head_dim,
/*max_tokens=*/context->max_tokens,
/*token_offset=*/input_batch_start);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
return status;
}
} else {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
command_buffer,
&model->f32_bf16w_matmul_qkv_fn,
model->attn_qkv_threadgroup_size,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
&context->qkv_activation_buffer,
/*output_offset=*/0,
&context->kvcache_buffer,
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/input_batch_size,
/*num_cols=*/model->embedding_dim,
/*num_q_heads=*/model->num_heads,
/*num_kv_heads=*/model->num_kv_heads,
/*attn_head_dim=*/model->head_dim,
/*token_offset=*/input_batch_start,
/*max_tokens=*/context->max_tokens,
/*rope_base=*/model->rope_theta,
/*interpolation_scale=*/model->interpolation_scale,
/*yarn_offset=*/model->yarn_offset,
/*yarn_scale=*/model->yarn_scale,
/*yarn_multiplier=*/model->yarn_multiplier);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
return status;
}
}
if (num_block_output_tokens != 0) {
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
command_buffer,
&model->f32_sdpa_q8_d64_fn,
&context->qkv_activation_buffer,
/*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&context->kvcache_buffer,
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
&model->shared_weight_buffer,
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
&context->sdpa_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
/*kv_stride=*/2 * context->max_tokens * model->head_dim,
num_block_output_tokens,
input_batch_start + input_batch_size - num_block_output_tokens,
model->num_heads, model->num_kv_heads, model->head_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
return status;
}
if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
command_buffer,
&model->f32_bf16w_dense_matmul_attn_output_fn,
&context->sdpa_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/num_block_output_tokens,
/*num_cols=*/model->num_heads * model->head_dim,
/*num_rows=*/model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
return status;
}
} else {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
command_buffer,
&model->f32_bf16w_matmul_fn,
model->attn_out_threadgroup_size,
&context->sdpa_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/num_block_output_tokens,
/*num_cols=*/model->num_heads * model->head_dim,
/*num_rows=*/model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
return status;
}
}
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&model->shared_weight_buffer,
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
num_block_output_tokens,
model->embedding_dim,
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
return status;
}
if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
command_buffer,
&model->f32_bf16w_dense_matmul_mlp_gate_fn,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
&context->gate_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
num_block_output_tokens,
model->embedding_dim,
model->num_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
return status;
}
} else {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
command_buffer,
&model->f32_bf16w_matmul_fn,
model->mlp_gate_threadgroup_size,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
&model->shared_weight_buffer,
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
&context->gate_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/num_block_output_tokens,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/model->num_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
return status;
}
}
const char* kernel_name = NULL;
switch (model->num_experts) {
case 32:
kernel_name = "f32_topk_softmax_e32_k4_fn";
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
command_buffer,
&model->f32_topk_softmax_e32_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
&context->control_buffer, /*control_offset=*/0,
num_block_output_tokens,
model->num_experts,
model->num_active_experts);
break;
case 128:
kernel_name = "f32_topk_softmax_e128_k4_fn";
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
command_buffer,
&model->f32_topk_softmax_e128_k4_fn,
&context->gate_activation_buffer, /*input_offset=*/0,
&context->expert_activation_buffer, /*output_offset=*/0,
&context->control_buffer, /*control_offset=*/0,
num_block_output_tokens,
model->num_experts,
model->num_active_experts);
break;
default:
status = gptoss_status_unsupported_argument;
GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
return status;
}
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
return status;
}
// If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
command_buffer,
&model->f32_expert_routing_metadata_fn,
&context->expert_activation_buffer,
/*expert_predictions_offset=*/0,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->token_to_expert_routing_buffer,
/*intra_expert_offsets_offset=*/0,
num_block_output_tokens * model->num_active_experts,
model->num_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_expert_routing_metadata kernel launch");
return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
command_buffer,
&model->f32_scatter_e4_fn,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_predictions_offset=*/0,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->token_to_expert_routing_buffer,
/*intra_expert_offsets_offset=*/0,
&context->swiglu_input_buffer,
/*output_offset=*/0,
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
return status;
}
// Dense MoE SwiGLU matmul.
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
command_buffer,
&model->f32_mf4w_moe_dense_matmul_swiglu_fn,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->swiglu_input_buffer,
/*input_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/0,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_swiglu_bias_offset,
&context->swiglu_activation_buffer,
/*output_offset=*/0,
model->swiglu_limit,
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_experts,
model->embedding_dim,
2 * model->mlp_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
return status;
}
// Dense MoE proj matmul.
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
command_buffer,
&model->f32_mf4w_moe_dense_matmul_fn,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->swiglu_activation_buffer,
/*input_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/model->mlp_out_block_offset,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_out_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_out_bias_offset,
&context->moe_activation_buffer,
/*output_offset=*/0,
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_experts,
model->mlp_dim,
model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
return status;
}
// Gather and accumulate.
status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
command_buffer,
&model->f32_gather_and_accumulate_e4_fn,
&context->moe_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_predictions_offset=*/0,
&context->expert_offset_buffer,
/*expert_offsets_offset=*/0,
&context->token_to_expert_routing_buffer,
/*intra_expert_offsets_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
return status;
}
} else {
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
command_buffer,
&model->f32_mf4w_moe_matmul_swiglu_fn,
model->mlp_swiglu_threadgroup_size,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/0,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_swiglu_bias_offset,
&context->swiglu_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
model->swiglu_limit,
model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_active_experts,
model->embedding_dim,
model->mlp_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
command_buffer,
&model->f32_mf4w_moe_matmul_fn,
model->mlp_out_threadgroup_size,
&context->swiglu_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&model->block_weight_buffers[n],
/*weight_block_offset=*/model->mlp_out_block_offset,
&model->block_weight_buffers[n],
/*weight_scale_offset=*/model->mlp_out_scale_offset,
&model->block_weight_buffers[n],
/*bias_offset=*/model->mlp_out_bias_offset,
&context->moe_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
model->per_expert_block_weight_size,
num_block_output_tokens,
model->num_active_experts,
model->mlp_dim,
model->embedding_dim);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
command_buffer,
&model->f32_accumulate_e4_fn,
model->mlp_acc_threadgroup_size,
model->max_threadgroups,
&context->moe_activation_buffer,
/*input_offset=*/0,
&context->expert_activation_buffer,
/*expert_offset=*/0,
&context->residual_activation_buffer,
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
&context->control_buffer,
/*control_offset=*/0,
model->embedding_dim,
num_block_output_tokens,
model->num_active_experts);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
return status;
}
}
}
}
if (output_batch_size != 0) {
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
command_buffer,
&model->f32_bf16w_rmsnorm_fn,
&context->residual_activation_buffer,
/*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
&model->shared_weight_buffer,
/*weight_offset=*/model->rmsnorm_weight_offset,
&context->rmsnorm_activation_buffer,
/*output_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/output_batch_size,
/*num_channels=*/model->embedding_dim,
model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
return status;
}
status = gptoss_metal_command_buffer_encode_fill_buffer(
command_buffer,
&context->argmax_buffer,
/*offset=*/0,
/*size=*/sizeof(uint64_t) * output_batch_size,
/*fill_value=*/0xFF);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode fill buffer command");
return status;
}
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
command_buffer,
&model->f32_bf16w_unembedding_fn,
model->unembedding_threadgroup_size,
model->max_threadgroups,
&context->rmsnorm_activation_buffer,
/*input_offset=*/0,
&model->shared_weight_buffer,
/*weight_offset=*/model->unembedding_weight_offset,
&context->score_buffer,
/*output_offset=*/0,
&context->argmax_buffer,
/*argmax_offset=*/0,
&context->control_buffer,
/*control_offset=*/0,
/*num_tokens=*/output_batch_size,
/*num_cols=*/model->embedding_dim,
/*num_rows=*/model->vocabulary_size);
if (status != gptoss_status_success) {
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
return status;
}
}
}
return gptoss_status_success;
}
/*
 * Tokenizes `text` greedily and appends the resulting tokens to the context.
 *
 * At every position the longest text token (scanned from tokenizer->tokens_ptr, num_text_tokens
 * entries) that is a prefix of the remaining text is chosen; among equally long candidates the lowest
 * token id wins. If the context still holds KV-cache entries past num_tokens, a token that matches the
 * one already stored keeps the cache, and the first mismatching token truncates it at that position.
 *
 * @param text            Bytes to tokenize; need not be NUL-terminated.
 * @param text_length     Number of bytes in `text`.
 * @param num_tokens_out  Optional; receives the number of tokens appended. It is written on every
 *                        return path, including errors, so callers can tell how much of the text
 *                        was consumed before a failure.
 *
 * @return gptoss_status_success when all text was consumed;
 *         gptoss_status_context_overflow if the context filled up first;
 *         gptoss_status_invalid_argument if some remaining prefix matches no text token.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
    gptoss_context_t context,
    const char* text,
    size_t text_length,
    size_t* num_tokens_out)
{
    enum gptoss_status status = gptoss_status_success;
    const struct gptoss_model* model = context->model;
    const struct gptoss_tokenizer* tokenizer = model->tokenizer;
    uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
    size_t num_appended_tokens = 0;
    while (text_length != 0) {
        if (context->num_tokens == context->max_tokens) {
            status = gptoss_status_context_overflow;
            break;
        }

        // Linear scan over the serialized vocabulary: each entry is a native-endian uint16_t length
        // followed by that many token bytes.
        const char* tokens = tokenizer->tokens_ptr;
        uint32_t best_token = UINT32_MAX;
        uint32_t best_token_length = 0;
        for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
            uint16_t token_length;
            memcpy(&token_length, tokens, sizeof(uint16_t));
            tokens += sizeof(uint16_t);
            // Cheap length checks first so memcmp only runs for candidates that fit and would beat the
            // current best; the strict '>' keeps the lowest id among equal-length matches.
            if (token_length <= text_length && token_length > best_token_length &&
                memcmp(text, tokens, token_length) == 0)
            {
                best_token = (uint32_t) t;
                best_token_length = token_length;
            }
            tokens += token_length;
        }
        if (best_token == UINT32_MAX) {
            GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
            // Leave through the common exit so num_tokens_out still reports the tokens that were
            // appended before the failure.
            status = gptoss_status_invalid_argument;
            break;
        }

        if (context->num_kv_tokens > context->num_tokens) {
            if (input_tokens[context->num_tokens] != best_token) {
                input_tokens[context->num_tokens] = best_token;
                // Invalidate the KV cache starting with the newly added token.
                context->num_kv_tokens = context->num_tokens;
            }
            context->num_tokens++;
        } else {
            input_tokens[context->num_tokens++] = best_token;
        }
        num_appended_tokens++;
        text += best_token_length;
        text_length -= best_token_length;
    }
    if (num_tokens_out != NULL) {
        *num_tokens_out = num_appended_tokens;
    }
    return status;
}
/*
 * Appends `num_tokens` token ids to the context.
 *
 * The ids are validated against the vocabulary before anything is appended. Where the context still
 * holds cached tokens past num_tokens (num_kv_tokens > num_tokens), incoming tokens equal to the cached
 * ones are accepted without invalidating the KV cache; the first mismatch truncates the cache there.
 *
 * @return gptoss_status_success when every token was appended;
 *         gptoss_status_invalid_argument if any id is >= the vocabulary size (nothing is appended);
 *         gptoss_status_context_overflow if the context filled up (tokens appended so far remain).
 */
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
    gptoss_context_t context,
    size_t num_tokens,
    const uint32_t* tokens)
{
    const struct gptoss_model* model = context->model;

    // All-or-nothing validation.
    for (size_t index = 0; index < num_tokens; index++) {
        const uint32_t token = tokens[index];
        if (token >= model->vocabulary_size) {
            GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
                token, index, model->vocabulary_size);
            return gptoss_status_invalid_argument;
        }
    }

    uint32_t* const stored_tokens = (uint32_t*) context->token_buffer.ptr;
    const uint32_t* pending = tokens;
    size_t num_pending = num_tokens;
    while (num_pending != 0) {
        if (context->num_tokens == context->max_tokens) {
            return gptoss_status_context_overflow;
        }

        size_t num_consumed;
        if (context->num_kv_tokens > context->num_tokens) {
            // Reuse cached slots while the new tokens agree with them; cut the cache at the first mismatch.
            const size_t num_comparable = math_min(context->num_kv_tokens - context->num_tokens, num_pending);
            num_consumed = 0;
            while (num_consumed < num_comparable &&
                   stored_tokens[context->num_tokens + num_consumed] == pending[num_consumed])
            {
                num_consumed++;
            }
            if (num_consumed != num_comparable) {
                // Invalidate the KV cache starting with the newly added tokens.
                context->num_kv_tokens = context->num_tokens + num_consumed;
            }
        } else {
            // Nothing cached ahead: copy as many tokens as still fit.
            num_consumed = math_min(context->max_tokens - context->num_tokens, num_pending);
            memcpy(stored_tokens + context->num_tokens, pending, num_consumed * sizeof(uint32_t));
        }
        context->num_tokens += num_consumed;
        pending += num_consumed;
        num_pending -= num_consumed;
    }
    return gptoss_status_success;
}
/*
 * Runs the model over every appended token that has no KV-cache entries yet, i.e. tokens in
 * [num_kv_tokens, num_tokens), without producing output logits. Blocks until the GPU finishes.
 * On success num_kv_tokens is advanced to num_tokens; on failure it is left unchanged.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_process(
    gptoss_context_t context)
{
    if (context->num_tokens <= context->num_kv_tokens) {
        return gptoss_status_success;
    }

    struct gptoss_metal_command_buffer command_buffer = {0};
    enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
    if (status == gptoss_status_success) {
        // Clear the abort flag that kernels check before doing any work.
        struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
        control->abort = 0;
        status = process_tokens(
            context,
            &command_buffer,
            /*input_tokens_offset=*/context->num_kv_tokens,
            /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
            /*num_output_tokens=*/0);
    }
    if (status == gptoss_status_success) {
        status = gptoss_metal_command_buffer_commit(&command_buffer);
    }
    if (status == gptoss_status_success) {
        status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
    }
    if (status == gptoss_status_success) {
        context->num_kv_tokens = context->num_tokens;
    }
    gptoss_metal_command_buffer_release(&command_buffer);
    return status;
}
/*
 * Autoregressively samples up to `max_tokens` tokens and appends them to the context.
 *
 * The forward passes and sampling kernels for all requested tokens are encoded into one Metal command
 * buffer, which is committed once and waited on; the sampled ids are then copied from the context's
 * token buffer into `tokens_out`.
 *
 * @param temperature     0.0f selects greedy decoding (argmax copy); otherwise softmax + sampling.
 * @param seed            RNG seed for stochastic sampling (the token position is used as RNG offset).
 * @param max_tokens      Maximum number of tokens to generate; `tokens_out` must hold this many.
 * @param tokens_out      Receives the generated token ids.
 * @param num_tokens_out  Receives the number of tokens generated (set to 0 first, always written).
 *
 * @return gptoss_status_success on success;
 *         gptoss_status_invalid_argument if max_tokens != 0 but the context holds no tokens;
 *         gptoss_status_context_overflow if the context filled up before max_tokens tokens were
 *         produced (the tokens that fit are still returned);
 *         otherwise the first error from encoding, committing or executing the command buffer, in
 *         which case the context's token count is restored to its value on entry.
 *
 * After a successful call the last sampled token is appended but NOT counted in num_kv_tokens: its KV
 * entries are only computed when it is fed back through the model (by the next process/sample call).
 */
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
    gptoss_context_t context,
    float temperature,
    uint64_t seed,
    size_t max_tokens,
    uint32_t* tokens_out,
    size_t* num_tokens_out)
{
    enum gptoss_status status = gptoss_status_success;
    const struct gptoss_model* model = context->model;
    struct gptoss_metal_command_buffer command_buffer = {0};

    *num_tokens_out = 0;

    if (max_tokens != 0 && context->num_tokens == 0) {
        // Sampling conditions on at least one token; otherwise the "re-process the last token" path
        // below would use the offset num_tokens - 1, which wraps around.
        GPTOSS_LOG_ERROR("cannot sample from an empty context");
        return gptoss_status_invalid_argument;
    }

    const size_t num_original_tokens = context->num_tokens;
    const size_t num_original_kv_tokens = context->num_kv_tokens;
    enum gptoss_status loop_status = gptoss_status_success;

    status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
    control->abort = 0;
    for (size_t t = 0; t < max_tokens; t++) {
        if (context->num_tokens >= context->max_tokens) {
            // No slot left for another sampled token: stop early but keep what was already encoded.
            loop_status = gptoss_status_context_overflow;
            break;
        }
        if (context->num_kv_tokens < context->num_tokens) {
            // Compute KV entries for every uncached token and logits for the last one.
            status = process_tokens(
                context,
                &command_buffer,
                /*input_tokens_offset=*/context->num_kv_tokens,
                /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
                /*num_output_tokens=*/1);
            context->num_kv_tokens = context->num_tokens;
        } else {
            // Everything is cached, but logits are not: re-run the last token to get them.
            status = process_tokens(
                context,
                &command_buffer,
                /*input_tokens_offset=*/context->num_tokens - 1,
                /*num_input_tokens=*/1,
                /*num_output_tokens=*/1);
        }
        if (status != gptoss_status_success) {
            goto rollback;
        }
        if (temperature != 0.0f) {
            assert(context->num_processed_tokens != 0);
            uint32_t num_threadgroups = 0;
            uint32_t num_dims_per_threadgroup = 0;
            status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
                &command_buffer,
                &model->f32_softmax_fn,
                /*threadgroup_size=*/512,
                model->max_threadgroups,
                &context->score_buffer,
                /*score_offset=*/0,
                &context->argmax_buffer,
                /*argmax_offset=*/0,
                &context->prob_buffer,
                /*prob_offset=*/0,
                &context->sum_buffer,
                /*sum_offset=*/0,
                &context->control_buffer,
                /*control_offset=*/0,
                model->vocabulary_size,
                /*num_tokens=*/1,
                temperature,
                &num_threadgroups,
                &num_dims_per_threadgroup);
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
                goto rollback;
            }
            status = gptoss_metal_command_buffer_encode_launch_f32_sample(
                &command_buffer,
                &model->f32_sample_fn,
                /*min_threadgroup_size=*/512,
                &context->prob_buffer,
                /*prob_offset=*/0,
                &context->sum_buffer,
                /*sum_offset=*/0,
                &context->token_buffer,
                /*token_offset=*/context->num_tokens * sizeof(uint32_t),
                &context->control_buffer,
                /*control_offset=*/0,
                /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
                /*rng_offset=*/context->num_tokens,
                /*num_blocks=*/num_threadgroups,
                /*num_channels=*/model->vocabulary_size,
                /*num_channels_per_block=*/num_dims_per_threadgroup);
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
                goto rollback;
            }
        } else {
            // Greedy decoding: the unembedding kernel already produced the argmax token id.
            status = gptoss_metal_command_buffer_encode_copy_buffer(
                &command_buffer,
                &context->argmax_buffer,
                /*input_offset=*/0,
                &context->token_buffer,
                /*output_offset=*/context->num_tokens * sizeof(uint32_t),
                /*size=*/sizeof(uint32_t));
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode copy buffer");
                goto rollback;
            }
        }
        // The sampled token lands in slot num_tokens, but its KV entries are only computed when it is
        // fed back through the model, so it must not be counted as cached. Cached entries at or past
        // this slot (possible after a reset) are stale now that the slot holds a new token.
        context->num_kv_tokens = context->num_tokens;
        context->num_tokens += 1;
    }

    status = gptoss_metal_command_buffer_commit(&command_buffer);
    if (status != gptoss_status_success) {
        GPTOSS_LOG_ERROR("failed to commit command buffer");
        goto rollback;
    }
    status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
    if (status != gptoss_status_success) {
        GPTOSS_LOG_ERROR("failed to wait for command buffer completion");
        goto rollback;
    }

    const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
    const size_t num_generated_tokens = context->num_tokens - num_original_tokens;
    memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
    *num_tokens_out = num_generated_tokens;
    status = loop_status;
    goto cleanup;

rollback:
    // The command buffer was not (fully) executed, so the tokens counted above were never reliably
    // written. Restore the entry token count, and treat the last original token's KV entry as untrusted
    // because this call may have started recomputing it on the GPU.
    context->num_tokens = num_original_tokens;
    {
        const size_t num_trusted_kv_tokens = num_original_tokens != 0 ? num_original_tokens - 1 : 0;
        context->num_kv_tokens =
            num_original_kv_tokens < num_trusted_kv_tokens ? num_original_kv_tokens : num_trusted_kv_tokens;
    }
cleanup:
    gptoss_metal_command_buffer_release(&command_buffer);
    return status;
}
/*
 * Logically empties the context by setting its token count to zero.
 *
 * num_kv_tokens and the contents of token_buffer are deliberately kept: when the tokens appended
 * afterwards match the ones already stored, their KV-cache entries are reused instead of recomputed.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
    gptoss_context_t context)
{
    context->num_tokens = 0;
    return gptoss_status_success;
}
/*
 * Adds one reference to the context; balance with gptoss_context_release.
 * The increment only needs atomicity; the acq_rel decrement in gptoss_context_release is what orders
 * the final teardown.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
    gptoss_context_t context)
{
    (void) atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
    return gptoss_status_success;
}
/*
 * Drops one reference to the context. The caller that drops the last reference releases every Metal
 * buffer owned by the context and the context's reference to its model, then zeroes and frees the
 * context itself. Passing NULL is a no-op. Always returns gptoss_status_success.
 */
enum gptoss_status GPTOSS_ABI gptoss_context_release(
    gptoss_context_t context)
{
    if (context == NULL) {
        return gptoss_status_success;
    }
    if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) != 1) {
        // Other references remain.
        return gptoss_status_success;
    }

    // Activation buffers
    gptoss_metal_buffer_release(&context->residual_activation_buffer);
    gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
    gptoss_metal_buffer_release(&context->qkv_activation_buffer);
    gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
    gptoss_metal_buffer_release(&context->gate_activation_buffer);
    gptoss_metal_buffer_release(&context->expert_activation_buffer);
    gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
    gptoss_metal_buffer_release(&context->moe_activation_buffer);
    gptoss_metal_buffer_release(&context->expert_offset_buffer);
    gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
    gptoss_metal_buffer_release(&context->swiglu_input_buffer);

    // Input/output buffers
    gptoss_metal_buffer_release(&context->control_buffer);
    gptoss_metal_buffer_release(&context->token_buffer);
    gptoss_metal_buffer_release(&context->score_buffer);
    gptoss_metal_buffer_release(&context->prob_buffer);
    gptoss_metal_buffer_release(&context->sum_buffer);
    gptoss_metal_buffer_release(&context->argmax_buffer);
    gptoss_metal_buffer_release(&context->kvcache_buffer);

    gptoss_model_release(context->model);
    // Scrub the struct so stale pointers are not left in freed memory.
    memset(context, 0, sizeof(struct gptoss_context));
    free(context);
    return gptoss_status_success;
}
================================================
FILE: gpt_oss/metal/source/convert.metal
================================================
#include
#include
#pragma METAL fp math_mode(safe)
#pragma METAL fp contract(off)
// Expands 4-bit block-quantized weights ("mf4"; presumably MXFP4, i.e. e2m1 nibbles sharing one
// power-of-two scale byte per 32 values; NOTE(review): confirm) into f32.
//
// Each "vec" is one uint4 from `blocks` (16 bytes = 32 nibbles) plus one byte from `scales`, and it
// produces 32 floats (8 float4) in `output`. Threadgroup `gid` owns the vecs
// [gid * num_vecs_per_threadgroup, min((gid + 1) * num_vecs_per_threadgroup, num_vecs)); thread `tid`
// handles every threadgroup_size-th vec of that range, starting at offset tid.
//
// NOTE: the template arguments of static_cast / as_type were lost in this copy of the file; the
// descriptions below assume the usual targets (uint / float for the scale, half4 -> float4 for lanes).
kernel void gptoss_mf4_f32_convert(
constant gptoss_convert_args& args [[ buffer(0) ]],
const device uint4* blocks [[ buffer(1) ]],
const device uchar* scales [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
const ulong thread_start = threadgroup_start + tid;
// ceil((threadgroup_end - thread_start) / threadgroup_size). If thread_start is past threadgroup_end by
// fewer than threadgroup_size vecs, the unsigned subtraction wraps and adding (threadgroup_size - 1)
// wraps back below threadgroup_size, so the count is 0. This assumes no threadgroup starts beyond
// num_vecs; TODO confirm the host-side grid size guarantees that.
uint num_iter = static_cast((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
blocks += thread_start;
scales += thread_start;
output += 8 * thread_start;
for (; num_iter != 0; num_iter--) {
const uint4 block = *blocks;
// Builds a float whose biased exponent field is (scale + 14), i.e. 2^(scale - 113). The +14 cancels
// the 2^-14 factor introduced by the half re-encoding below, giving value * 2^(scale - 127).
const float scale = as_type((static_cast(*scales) + 14) << 23);
// Move each nibble into bits 1..4 of its byte: "block + block" (<< 1) for low nibbles (even positions),
// ">> 3" for high nibbles (odd positions). Bits shifted in from neighbouring bytes are masked away next.
uint4 block02468ACEGIKMOQSU = block + block;
uint4 block13579BDFHJLNPRTV = block >> 3;
block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
// Per byte: sign at bit 4, exponent at bits 3..2, mantissa at bit 1. Adding 0x70 (no carry out of the
// byte, max 0x8E) moves the sign from bit 4 to bit 7; masking with 0x8E then drops the 0x70 filler,
// leaving sign | exponent | mantissa laid out like the high byte of an IEEE half.
block02468ACEGIKMOQSU += 0x70707070u;
block13579BDFHJLNPRTV += 0x70707070u;
block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
// Split each register into two sets of 16-bit lanes whose high byte is one decoded value (low byte zero).
// Variable names list the nibble positions (0-9, A-V = 0..31) the lanes hold.
const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;
const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;
const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
// Reinterpret pairs of words as 4 halves, widen to float and apply the shared scale.
const float4 block048C = static_cast(as_type(block048CGKOS.xy)) * scale;
const float4 blockGKOS = static_cast(as_type(block048CGKOS.zw)) * scale;
const float4 block26AE = static_cast(as_type(block26AEIMQU.xy)) * scale;
const float4 blockIMQU = static_cast(as_type(block26AEIMQU.zw)) * scale;
const float4 block159D = static_cast(as_type(block159DHLPT.xy)) * scale;
const float4 blockHLPT = static_cast(as_type(block159DHLPT.zw)) * scale;
const float4 block37BF = static_cast(as_type(block37BFJNRV.xy)) * scale;
const float4 blockJNRV = static_cast(as_type(block37BFJNRV.zw)) * scale;
// Re-interleave so that output[i] holds the values of nibble positions 4i..4i+3 in order.
output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };
output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };
output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };
output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };
output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };
output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };
output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };
output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };
// Advance to this thread's next vec within the threadgroup's range.
blocks += threadgroup_size;
scales += threadgroup_size;
output += 8 * threadgroup_size;
}
}
================================================
FILE: gpt_oss/metal/source/embeddings.metal
================================================
#include
#pragma METAL fp math_mode(safe)
#pragma METAL fp contract(off)
// Embedding lookup: threadgroup `gid` handles input token gid. It copies row tokens[gid] of the bf16
// embedding table (args.num_vecs bfloat4 per row) into row gid of the f32 output, converting each
// element. Threads stride over the row by threadgroup_size.
// Does nothing when control->abort is non-zero.
// NOTE: the static_cast template argument was lost in this copy of the file (presumably float4).
kernel void gptoss_bf16_f32_embeddings(
constant gptoss_embeddings_args& args [[ buffer(0) ]],
const device uint* tokens [[ buffer(1) ]],
const device bfloat4* weights [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
const device gptoss_control* control [[ buffer(4) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
if (control->abort != 0) {
return;
}
const uint t = tokens[gid];
// Select the embedding row for token t and the output row for this input position.
weights += t * args.num_vecs;
output += gid * args.num_vecs;
for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
const bfloat4 w = weights[i];
output[i] = static_cast(w);
}
}
================================================
FILE: gpt_oss/metal/source/expert_routing_metadata.metal
================================================
#include
#include
#include
#include
constant uint kMaxExperts = 128;

// Builds routing tables from the top-k expert predictions (args.tokens prediction slots; the host call
// in this file passes num_tokens * num_active_experts):
//   * intra_expert_offsets[i]: rank of slot i among all slots routed to the same expert;
//   * expert_offsets[e]: exclusive prefix sum of per-expert slot counts, with the grand total written
//     to expert_offsets[num_experts], so num_experts + 1 entries are written.
// Counting uses threadgroup atomics and every threadgroup walks all slots, so results are only
// consistent for a single-threadgroup dispatch (NOTE(review): confirm against the launcher).
// The order of slots within one expert depends on atomic scheduling and is not deterministic.
kernel void gptoss_f32_expert_routing_metadata(
    constant gptoss_expert_routing_metadata_args& args [[ buffer(0) ]],
    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(1) ]],
    device uint* __restrict__ expert_offsets [[ buffer(2) ]],
    device uint* __restrict__ intra_expert_offsets [[ buffer(3) ]],
    uint tg_size [[threads_per_threadgroup]],
    uint tid [[thread_position_in_threadgroup]])
{
    assert(args.num_experts <= kMaxExperts);

    threadgroup metal::atomic_uint expert_histogram[kMaxExperts];

    // Phase 1: zero the per-expert counters cooperatively.
    for (uint expert = tid; expert < args.num_experts; expert += tg_size) {
        metal::atomic_store_explicit(&expert_histogram[expert], 0u, metal::memory_order_relaxed);
    }
    threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // Phase 2: count slots per expert; the pre-increment count is the slot's rank within its expert.
    for (uint slot = tid; slot < args.tokens; slot += tg_size) {
        const uint expert = expert_predictions[slot].expert_id;
        intra_expert_offsets[slot] =
            metal::atomic_fetch_add_explicit(&expert_histogram[expert], 1u, metal::memory_order_relaxed);
    }
    threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // Phase 3: one thread turns the histogram into an exclusive prefix sum.
    if (tid != 0) {
        return;
    }
    uint running_total = 0;
    for (uint expert = 0; expert < args.num_experts; expert++) {
        expert_offsets[expert] = running_total;
        running_total += metal::atomic_load_explicit(&expert_histogram[expert], metal::memory_order_relaxed);
    }
    expert_offsets[args.num_experts] = running_total;
}
================================================
FILE: gpt_oss/metal/source/gather_and_accumulate.metal
================================================
#include
#include
#include
#include
// TODO(ibrahim): This is not optimal as each thread only gathers and accumulates a single float4. To amortize the
// cost of reading the expert, offset and scales for a token, we should let each thread gather and accumulate several
// float4s.
// Combines the per-expert MoE outputs back into token order and adds them onto `out`.
//
// One thread per (token row = gid.y, float4 column = gid.x); out-of-range threads return immediately.
// For each of the token's 4 predictions (k must be 4), the source row in the expert-grouped `in`
// buffer is expert_offsets[expert_id] + intra_expert_offsets[row * 4 + j]. The thread then computes
//   out[row][col..col+3] += sum_j score_j * in[r_j][col..col+3]
// using fma, accumulating onto whatever `out` already holds.
// args.token_stride (D) is the row width in floats and must be a multiple of 4.
// NOTE: the reinterpret_cast template arguments were lost in this copy of the file
// (presumably device float4* / const device uint4* / const device float4*).
kernel void gptoss_f32_gather_and_accumulate_e4(
constant gptoss_gather_args& args [[ buffer(0) ]],
const device float* in [[ buffer(1) ]],
const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
const device uint* expert_offsets [[ buffer(3) ]],
const device uint* intra_expert_offsets [[ buffer(4) ]],
device float* out [[ buffer(5) ]],
uint3 gid [[thread_position_in_grid]])
{
const uint T = args.tokens;
const uint k = args.active_experts_per_token;
const uint D = args.token_stride;
assert((D & 3u) == 0);
assert(k == 4);
const uint row = gid.y;
if (row >= T) {
return;
}
const uint col_vec4 = gid.x;
const uint col = col_vec4 * 4u;
if (col >= D) {
return;
}
device float4* dst4 = reinterpret_cast(out + row * D + col);
// The token's k = 4 predictions are stored contiguously starting at row * k.
const uint base = row * k;
const gptoss_expert_prediction expert0 = expert_predictions[base];
const gptoss_expert_prediction expert1 = expert_predictions[base + 1];
const gptoss_expert_prediction expert2 = expert_predictions[base + 2];
const gptoss_expert_prediction expert3 = expert_predictions[base + 3];
const uint expert0_id = expert0.expert_id;
const uint expert1_id = expert1.expert_id;
const uint expert2_id = expert2.expert_id;
const uint expert3_id = expert3.expert_id;
const float scale0 = expert0.score;
const float scale1 = expert1.score;
const float scale2 = expert2.score;
const float scale3 = expert3.score;
// Loads the 4 intra-expert ranks in one vector read; base is a multiple of 4, so this is aligned
// provided the buffer itself is 16-byte aligned (assumed; TODO confirm).
const uint4 current_intra_expert_offsets =
*reinterpret_cast(&intra_expert_offsets[base]);
// Get the row indices for the current expert ids
const uint r0 = expert_offsets[expert0_id] + current_intra_expert_offsets.x;
const uint r1 = expert_offsets[expert1_id] + current_intra_expert_offsets.y;
const uint r2 = expert_offsets[expert2_id] + current_intra_expert_offsets.z;
const uint r3 = expert_offsets[expert3_id] + current_intra_expert_offsets.w;
const device float4* src0 =
reinterpret_cast(in + r0 * D + col);
const device float4* src1 =
reinterpret_cast(in + r1 * D + col);
const device float4* src2 =
reinterpret_cast(in + r2 * D + col);
const device float4* src3 =
reinterpret_cast(in + r3 * D + col);
// Accumulate onto the existing destination value (residual-style add).
float4 acc = *dst4;
acc = metal::fma(*src0, scale0, acc);
acc = metal::fma(*src1, scale1, acc);
acc = metal::fma(*src2, scale2, acc);
acc = metal::fma(*src3, scale3, acc);
*dst4 = acc;
}
================================================
FILE: gpt_oss/metal/source/generate.c
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "internal/model.h"
// Process-wide profiling counters, reported by print_profile(). They are atomics,
// presumably so that the SIGINT handler (which also calls print_profile) can read them.
struct {
// NOTE(review): never updated anywhere in generate.c; it is only initialized to 0.
atomic_uint_least64_t inference_bytes;
// Number of prompt tokens appended to the context before generation starts.
atomic_size_t num_prefill_tokens;
// Wall time of prompt tokenization + processing; main() records it only when the
// first token has been generated.
atomic_uint_least64_t prefill_microseconds;
// Number of tokens printed so far (the terminating "return" token is not counted).
atomic_size_t num_generated_tokens;
// Summed gptoss_context_sample() time of every generated token after the first.
atomic_uint_least64_t generation_microseconds;
} globals = {
.inference_bytes = 0,
.num_prefill_tokens = 0,
.prefill_microseconds = 0,
.num_generated_tokens = 0,
.generation_microseconds = 0,
};
// Command-line configuration produced by parse_options().
// model: positional argument, path of the model file (required).
// prompt: -p / --prompt (required).
// context_length: --context-length; 0 means "use the model's maximum" (see main).
// max_tokens: -n / --max-tokens; 0 (the default) means generate until the return token.
// temperature: -t / --temperature; parse_options accepts only values in [0, 2].
// verbose: -v / --verbose; prints allocation sizes and the model load time.
struct options {
const char* model;
const char* prompt;
size_t context_length;
size_t max_tokens;
float temperature;
bool verbose;
};
// Converts the distance between two mach_continuous_time() readings to seconds.
// The tick -> nanosecond ratio (numer / denom) is queried once and cached; a zero
// denominator marks the cache as not yet filled.
static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
    static mach_timebase_info_data_t timebase = {0};
    if (!timebase.denom) {
        mach_timebase_info(&timebase);
    }
    const double ticks = (double) (end_timestamp - start_timestamp);
    const double scaled_ticks = ticks * (double) timebase.numer;
    return scaled_ticks / ((double) timebase.denom * 1.0e+9);
}
// Converts the distance between two mach_continuous_time() readings to whole
// microseconds, rounded to nearest. The tick -> nanosecond ratio is cached after the
// first call (denom == 0 means "not queried yet").
static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
    static mach_timebase_info_data_t timebase = {0};
    if (!timebase.denom) {
        mach_timebase_info(&timebase);
    }
    const uint64_t ticks = end_timestamp - start_timestamp;
    // ticks * numer / denom is nanoseconds; the extra factor of 1000 gives microseconds.
    const uint64_t divisor = (uint64_t) timebase.denom * 1000u;
    const uint64_t scaled_ticks = ticks * timebase.numer;
    // Adding half the divisor turns the truncating division into round-to-nearest.
    return (scaled_ticks + divisor / 2) / divisor;
}
// Writes the one-line command-line synopsis to stdout.
// NOTE(review): as written, the synopsis does not mention the positional model path or
// the -t/--temperature, --context-length, -v/--verbose and --help options that
// parse_options() accepts.
static void print_usage(const char* program_name) {
    fprintf(stdout, "Usage: %s [-p ] [-n ]\n", program_name);
}
// Exits with an error message and the usage synopsis when the option at argv[i]
// expects a value but is the last command-line argument.
static void require_option_value(int i, int argc, char** argv) {
    if (i + 1 >= argc) {
        fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
        print_usage(argv[0]);
        exit(EXIT_FAILURE);
    }
}

// Parses `text` as a non-negative base-10 integer that spans the whole string.
// strtoul() accepts a leading minus sign and negates the result in unsigned arithmetic
// ("-1" becomes ULONG_MAX), so any '-' is rejected explicitly: in a string that parsed
// completely, a '-' can only be that sign.
// Returns true and stores the value on success; returns false (and leaves *value_out
// untouched) on empty input, trailing characters, or a negative number.
static bool parse_size_value(const char* text, size_t* value_out) {
    char* end = NULL;
    const unsigned long value = strtoul(text, &end, 10);
    if (end == text || *end != '\0' || strchr(text, '-') != NULL) {
        return false;
    }
    *value_out = (size_t) value;
    return true;
}

// Parses the command line into a `struct options`.
// Accepted arguments: a positional model path (required), -p/--prompt TEXT (required),
// -n/--max-tokens N (N > 0), --context-length N (0 = model maximum),
// -t/--temperature T (0 <= T <= 2), -v/--verbose and --help.
// Any missing or malformed argument prints an error to stderr and exits with
// EXIT_FAILURE; --help prints the usage synopsis and exits with EXIT_SUCCESS.
struct options parse_options(int argc, char** argv) {
    struct options options = (struct options) {
        .model = NULL,
        .prompt = NULL,
        .context_length = 0,
        .max_tokens = 0,
        .temperature = 0.0f,
        .verbose = false,
    };
    if (argc < 2) {
        fprintf(stderr, "Error: missing required command-line argument\n");
        print_usage(argv[0]);
        exit(EXIT_FAILURE);
    }
    for (int i = 1; i < argc; i++) {
        const char* arg = argv[i];
        if (strcmp(arg, "--help") == 0) {
            print_usage(argv[0]);
            exit(EXIT_SUCCESS);
        } else if (strcmp(arg, "-p") == 0 || strcmp(arg, "--prompt") == 0) {
            require_option_value(i, argc, argv);
            options.prompt = argv[++i];
        } else if (strcmp(arg, "--context-length") == 0) {
            require_option_value(i, argc, argv);
            const char* value = argv[++i];
            if (!parse_size_value(value, &options.context_length)) {
                fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", value);
                exit(EXIT_FAILURE);
            }
        } else if (strcmp(arg, "-n") == 0 || strcmp(arg, "--max-tokens") == 0) {
            require_option_value(i, argc, argv);
            const char* value = argv[++i];
            if (!parse_size_value(value, &options.max_tokens)) {
                fprintf(stderr, "Error: failed to parse max tokens value \"%s\"\n", value);
                exit(EXIT_FAILURE);
            }
            // 0 is reserved internally for "no limit" and is not accepted on the command line.
            if (options.max_tokens == 0) {
                fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
                exit(EXIT_FAILURE);
            }
        } else if (strcmp(arg, "-t") == 0 || strcmp(arg, "--temperature") == 0) {
            require_option_value(i, argc, argv);
            const char* value = argv[++i];
            char* value_end = NULL;
            options.temperature = strtof(value, &value_end);
            if (value_end == value || *value_end != '\0') {
                fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", value);
                exit(EXIT_FAILURE);
            }
            // signbit() also rejects -0.0; the negated comparison also rejects NaN.
            if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
                fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
                exit(EXIT_FAILURE);
            }
        } else if (strcmp(arg, "-v") == 0 || strcmp(arg, "--verbose") == 0) {
            options.verbose = true;
        } else if (options.model == NULL) {
            // The first non-option argument is the model path.
            options.model = argv[i];
        } else {
            fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
            print_usage(argv[0]);
            exit(EXIT_FAILURE);
        }
    }
    if (options.model == NULL) {
        fprintf(stderr, "Error: missing required model argument\n");
        print_usage(argv[0]);
        exit(EXIT_FAILURE);
    }
    if (options.prompt == NULL) {
        fprintf(stderr, "Error: missing required prompt argument\n");
        print_usage(argv[0]);
        exit(EXIT_FAILURE);
    }
    return options;
}
// Prints the prefill and generation throughput accumulated in `globals`; called at the
// end of main() and from the SIGINT handler.
// A rate is printed only when its measured duration is non-zero. main() records prefill
// time only once the first token has been generated, and generation time only from the
// second token on. Without the guard, an interrupted or one-token run would divide by
// zero and print "inf tokens/second".
static void print_profile(void) {
    const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
    const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
    const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
    const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
    // Separate the statistics from the generated text that precedes them on stdout.
    if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
        printf("\n");
    }
    if (num_prefill_tokens != 0 && prefill_microseconds != 0) {
        printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
            num_prefill_tokens,
            (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
    }
    if (num_generated_tokens != 0 && generation_microseconds != 0) {
        printf("Generation speed (%zu tokens): %.1f tokens/second\n",
            num_generated_tokens,
            (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
    }
}
// SIGINT handler installed by main(): reports the throughput measured so far, then
// terminates the process with a success status.
// NOTE(review): printf (inside print_profile) and exit are not on the POSIX list of
// async-signal-safe functions. Consider having this handler only set a
// volatile sig_atomic_t flag that the generation loop in main() checks.
static void ctrl_c_handler(int signum) {
    (void) signum;  // unused: main() routes only SIGINT here
    print_profile();
    exit(EXIT_SUCCESS);
}
// Command-line driver.
// Loads the model named by options.model and feeds it the prompt. It then samples and
// prints one token at a time until the tokenizer's "return" special token is produced
// or options.max_tokens tokens have been generated (0 = no limit).
// Returns EXIT_SUCCESS on completion and EXIT_FAILURE on any error. In both cases the
// context, tokenizer and model are released before returning.
int main(int argc, char *argv[]) {
    enum gptoss_status status;
    gptoss_model_t model = NULL;
    gptoss_tokenizer_t tokenizer = NULL;
    gptoss_context_t context = NULL;
    int exit_code = EXIT_FAILURE;

    // Ctrl-C prints the throughput gathered so far (see ctrl_c_handler).
    // Zero the whole struct and empty the mask first. An uninitialized `struct sigaction`
    // would pass stack garbage in sa_flags/sa_mask to sigaction().
    struct sigaction act;
    memset(&act, 0, sizeof(act));
    sigemptyset(&act.sa_mask);
    act.sa_handler = ctrl_c_handler;
    sigaction(SIGINT, &act, NULL);

    // Unbuffered stdout so each token appears as soon as it is printed.
    setvbuf(stdout, NULL, _IONBF, 0);

    struct options options = parse_options(argc, argv);

    const uint64_t load_start_time = mach_continuous_time();
    status = gptoss_model_create_from_file(options.model, &model);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
        goto cleanup;
    }

    size_t max_model_context_length = 0;
    status = gptoss_model_get_max_context_length(model, &max_model_context_length);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to query maximum context length\n");
        goto cleanup;
    }
    assert(max_model_context_length != 0);
    // A context length of 0 means "use the model's maximum".
    if (options.context_length == 0) {
        options.context_length = max_model_context_length;
    } else if (options.context_length > max_model_context_length) {
        fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
        goto cleanup;
    }

    status = gptoss_model_get_tokenizer(model, &tokenizer);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
        goto cleanup;
    }

    // Generation stops as soon as this special token is sampled.
    uint32_t return_token_id = UINT32_MAX;
    status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to query end-of-text token ID\n");
        goto cleanup;
    }

    status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to create Context object\n");
        goto cleanup;
    }

    if (options.verbose) {
        printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
        printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
        printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
        printf(" Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
    }

    const uint64_t load_end_time = mach_continuous_time();
    const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
    if (options.verbose) {
        printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
    }

    // Prefill: tokenize the prompt into the context and run it through the model.
    const uint64_t prefill_start_time = mach_continuous_time();
    size_t num_prefill_tokens = 0;
    status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
        goto cleanup;
    }
    atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
    status = gptoss_context_process(context);
    if (status != gptoss_status_success) {
        fprintf(stderr, "Error: failed to process Context object\n");
        goto cleanup;
    }
    const uint64_t prefill_end_time = mach_continuous_time();

    while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
        uint32_t predicted_token = UINT32_MAX;
        size_t num_predicted_tokens = 0;
        const uint64_t inference_start_timestamp = mach_continuous_time();
        status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
        if (status != gptoss_status_success) {
            fprintf(stderr, "Error: failed to sample from the Context object\n");
            goto cleanup;
        }
        const uint64_t inference_end_timestamp = mach_continuous_time();

        if (predicted_token == return_token_id) {
            // Yield token -> stop generation
            break;
        }

        // Unembedding: detokenize
        size_t token_size = 0;
        const void* token_ptr = NULL;
        status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
        if (status != gptoss_status_success) {
            fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
            goto cleanup;
        }

        const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
        if (previous_num_generated_tokens == 0) {
            // Prefill time is published together with the first generated token.
            atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
        } else {
            // NOTE(review): the first token's sampling time is never accumulated, yet that
            // token is counted in num_generated_tokens. Confirm this is intended.
            atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
        }

        printf("%.*s", (int) token_size, (const char*) token_ptr);

        status = gptoss_context_append_tokens(context, 1, &predicted_token);
        if (status != gptoss_status_success) {
            fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
            goto cleanup;
        }
    }

    print_profile();
    exit_code = EXIT_SUCCESS;

cleanup:
    // Reached on success and on every error; objects that were never created are still
    // NULL here, as in the original error path.
    gptoss_context_release(context);
    gptoss_tokenizer_release(tokenizer);
    gptoss_model_release(model);
    return exit_code;
}
================================================
FILE: gpt_oss/metal/source/include/internal/datatype.h
================================================
#pragma once
#include
#include
// Storage-only wrappers for reduced-precision floating-point formats. Each struct holds
// just the raw bit pattern in `bits`; no arithmetic is defined here.
// GPTOSS_DENSELY_PACKED_STRUCTURE removes padding and GPTOSS_ALIGN restores the natural
// alignment of `bits`. The static_asserts pin each size to the format's storage width.

// bfloat16: the upper 16 bits of an IEEE fp32 value (datatype.hpp widens it by shifting
// the bits left by 16).
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(2) uint16_t bits;
} gptoss_bfloat16;
static_assert(sizeof(gptoss_bfloat16) == 2, "bfloat16 size is not 2 bytes");
// IEEE half precision (datatype.hpp bit-casts the bits to _Float16).
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(2) uint16_t bits;
} gptoss_float16;
static_assert(sizeof(gptoss_float16) == 2, "float16 size is not 2 bytes");
// Judging by the name: an unsigned, exponent-only 8-bit format (8 exponent bits,
// no mantissa), typically used for block scales. TODO confirm against the decoder.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8ue8m0;
static_assert(sizeof(gptoss_float8ue8m0) == 1, "gptoss_float8ue8m0 size is not 1 bytes");
// Judging by the name: 8-bit float with 5 exponent and 2 mantissa bits.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8e5m2;
static_assert(sizeof(gptoss_float8e5m2) == 1, "float8e5m2 size is not 1 bytes");
// 8-bit float with 4 exponent and 3 mantissa bits; datatype.hpp decodes it through a
// 256-entry lookup table indexed by `bits`.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8e4m3;
static_assert(sizeof(gptoss_float8e4m3) == 1, "gptoss_float8e4m3 size is not 1 bytes");
// Judging by the name (the "x2" suffix): two 4-bit e2m1 values packed into one byte.
// Nibble order is not shown here; TODO confirm.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float4e2m1x2;
static_assert(sizeof(gptoss_float4e2m1x2) == 1, "gptoss_float4e2m1x2 size is not 1 bytes");
================================================
FILE: gpt_oss/metal/source/include/internal/datatype.hpp
================================================
#pragma once
#include
#include
namespace gptoss {
// Widening conversion from a narrow storage format (NarrowT) to a wider arithmetic type
// (WideT). Only the explicit specializations below are defined; the primary template is
// declared without a body.
template
WideT upcast(NarrowT);
// bfloat16 -> float. A bfloat16 holds the upper 16 bits of an IEEE fp32, so moving the
// bits into the high half of a 32-bit word and bit-casting recovers the value exactly.
template <>
inline float upcast(gptoss_bfloat16 bf16_value) {
const uint32_t bits = static_cast(bf16_value.bits) << 16;
return std::bit_cast(bits);
}
// float16 -> float: reinterpret the bits as the compiler's _Float16, then widen.
template <>
inline float upcast(gptoss_float16 fp16_value) {
return static_cast(std::bit_cast<_Float16>(fp16_value.bits));
}
// float8e4m3 -> float via a 256-entry table indexed by the raw fp8 byte.
// Despite its name, the table stores bfloat16 bit patterns: each entry is wrapped in a
// gptoss_bfloat16 and widened through the bfloat16 conversion.
// Table layout:
// - entries 0x00-0x7F are non-negative and 0x80-0xFF are their negatives;
// - 0x00/0x80 map to +0/-0;
// - the largest finite magnitude is 0x43E0 (448.0);
// - 0x7F/0xFF map to the bfloat16 NaN patterns 0x7FF0/0xFFF0, and there are no infinities.
template <>
inline float upcast(gptoss_float8e4m3 fp8_value) {
static constexpr uint16_t fp8e4m3_to_fp32[256] = {
0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60,
0x3C80, 0x3C90, 0x3CA0, 0x3CB0, 0x3CC0, 0x3CD0, 0x3CE0, 0x3CF0,
0x3D00, 0x3D10, 0x3D20, 0x3D30, 0x3D40, 0x3D50, 0x3D60, 0x3D70,
0x3D80, 0x3D90, 0x3DA0, 0x3DB0, 0x3DC0, 0x3DD0, 0x3DE0, 0x3DF0,
0x3E00, 0x3E10, 0x3E20, 0x3E30, 0x3E40, 0x3E50, 0x3E60, 0x3E70,
0x3E80, 0x3E90, 0x3EA0, 0x3EB0, 0x3EC0, 0x3ED0, 0x3EE0, 0x3EF0,
0x3F00, 0x3F10, 0x3F20, 0x3F30, 0x3F40, 0x3F50, 0x3F60, 0x3F70,
0x3F80, 0x3F90, 0x3FA0, 0x3FB0, 0x3FC0, 0x3FD0, 0x3FE0, 0x3FF0,
0x4000, 0x4010, 0x4020, 0x4030, 0x4040, 0x4050, 0x4060, 0x4070,
0x4080, 0x4090, 0x40A0, 0x40B0, 0x40C0, 0x40D0, 0x40E0, 0x40F0,
0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170,
0x4180, 0x4190, 0x41A0, 0x41B0, 0x41C0, 0x41D0, 0x41E0, 0x41F0,
0x4200, 0x4210, 0x4220, 0x4230, 0x4240, 0x4250, 0x4260, 0x4270,
0x4280, 0x4290, 0x42A0, 0x42B0, 0x42C0, 0x42D0, 0x42E0, 0x42F0,
0x4300, 0x4310, 0x4320, 0x4330, 0x4340, 0x4350, 0x4360, 0x4370,
0x4380, 0x4390, 0x43A0, 0x43B0, 0x43C0, 0x43D0, 0x43E0, 0x7FF0,
0x8000, 0xBB00, 0xBB80, 0xBBC0, 0xBC00, 0xBC20, 0xBC40, 0xBC60,
0xBC80, 0xBC90, 0xBCA0, 0xBCB0, 0xBCC0, 0xBCD0, 0xBCE0, 0xBCF0,
0xBD00, 0xBD10, 0xBD20, 0xBD30, 0xBD40, 0xBD50, 0xBD60, 0xBD70,
0xBD80, 0xBD90, 0xBDA0, 0xBDB0, 0xBDC0, 0xBDD0, 0xBDE0, 0xBDF0,
0xBE00, 0xBE10, 0xBE20, 0xBE30, 0xBE40, 0xBE50, 0xBE60, 0xBE70,
0xBE80, 0xBE90, 0xBEA0, 0xBEB0, 0xBEC0, 0xBED0, 0xBEE0, 0xBEF0,
0xBF00, 0xBF10, 0xBF20, 0xBF30, 0xBF40, 0xBF50, 0xBF60, 0xBF70,
0xBF80, 0xBF90, 0xBFA0, 0xBFB0, 0xBFC0, 0xBFD0, 0xBFE0, 0xBFF0,
0xC000, 0xC010, 0xC020, 0xC030, 0xC040, 0xC050, 0xC060, 0xC070,
0xC080, 0xC090, 0xC0A0, 0xC0B0, 0xC0C0, 0xC0D0, 0xC0E0, 0xC0F0,
0xC100, 0xC110, 0xC120, 0xC130, 0xC140, 0xC150, 0xC160, 0xC170,
0xC180, 0xC190, 0xC1A0, 0xC1B0, 0xC1C0, 0xC1D0, 0xC1E0, 0xC1F0,
0xC200, 0xC210, 0xC220, 0xC230, 0xC240, 0xC250, 0xC260, 0xC270,
0xC280, 0xC290, 0xC2A0, 0xC2B0, 0xC2C0, 0xC2D0, 0xC2E0, 0xC2F0,
0xC300, 0xC310, 0xC320, 0xC330, 0xC340, 0xC350, 0xC360, 0xC370,
0xC380, 0xC390, 0xC3A0, 0xC3B0, 0xC3C0, 0xC3D0, 0xC3E0, 0xFFF0,
};
const gptoss_bfloat16 bf16_value{.bits = fp8e4m3_to_fp32[fp8_value.bits]};
return upcast(bf16_value);
}
// float -> double: an exact widening.
template <>
inline double upcast(float fp32_value) {
return static_cast(fp32_value);
}
// The remaining double conversions go through float first; every step is exact.
template <>
inline double upcast(gptoss_bfloat16 bf16_value) {
const float fp32_value = upcast(bf16_value);
return upcast(fp32_value);
}
template <>
inline double upcast(gptoss_float16 fp16_value) {
const float fp32_value = upcast(fp16_value);
return upcast(fp32_value);
}
template <>
inline double upcast(gptoss_float8e4m3 fp8_value) {
const float fp32_value = upcast(fp8_value);
return upcast(fp32_value);
}
} // namespace gptoss
================================================
FILE: gpt_oss/metal/source/include/internal/kernel-args.h
================================================
#pragma once
#if !defined(__METAL_VERSION__)
#include
#endif
// Compile-time tile sizes for the QKV, attention-output, MLP-gate and MoE dense matmul
// kernels (with and without the fused SwiGLU). Presumably Bm/Bn/Bk are the
// threadgroup tile extents along M/N/K and Sg_Bm/Sg_Bn the per-simdgroup sub-tile.
// TODO confirm against the Metal kernels that consume them.
// TODO(ibahmed): specialize using metal function constants.
#define QKV_Bm 64
#define QKV_Bn 64
#define QKV_Bk 32
#define QKV_Sg_Bm 32
#define QKV_Sg_Bn 32
#define ATTN_OUTPUT_Bm 32
#define ATTN_OUTPUT_Bn 64
#define ATTN_OUTPUT_Bk 64
#define ATTN_OUTPUT_Sg_Bm 32
#define ATTN_OUTPUT_Sg_Bn 16
#define MLP_GATE_Bm 64
#define MLP_GATE_Bn 16
#define MLP_GATE_Bk 64
#define MLP_GATE_Sg_Bm 16
#define MLP_GATE_Sg_Bn 16
#define MOE_DENSE_MATMUL_SWIGLU_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Bn 64
#define MOE_DENSE_MATMUL_SWIGLU_Bk 16
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bn 16
#define MOE_DENSE_MATMUL_Bm 32
#define MOE_DENSE_MATMUL_Bn 64
#define MOE_DENSE_MATMUL_Bk 16
#define MOE_DENSE_MATMUL_Sg_Bm 32
#define MOE_DENSE_MATMUL_Sg_Bn 16
// Layout-sensitive: this header is also compiled as Metal (note the __METAL_VERSION__
// guard around its include). Field order and types must therefore stay identical for the
// host and the shaders.

// One expert chosen for a token by the router. Metal kernels use `score` as the weight
// applied to that expert's output when combining expert results.
struct gptoss_expert_prediction {
uint32_t expert_id;
float score;
};
// Control block bound to kernel launches (the control_buffer arguments in
// metal-kernels.h). Presumably a non-zero `abort` tells kernels to skip their work;
// TODO confirm in the kernels.
struct gptoss_control {
uint32_t abort;
};
// Arguments for the top-k (presumably expert-selection) kernel.
struct gptoss_topk_args {
uint32_t num_vecs_per_token;
};
// Scaled-dot-product-attention arguments. `window` presumably limits attention to the
// most recent tokens (sliding window); confirm in the kernel.
struct gptoss_sdpa_args {
uint32_t qkv_dim;
uint32_t num_kv_tokens;
uint32_t kv_stride;
uint32_t window;
};
// Random-fill kernel arguments (64-bit counts). `offset` presumably advances the RNG
// counter so successive fills with the same seed do not repeat.
struct gptoss_u32_fill_random_args {
uint64_t num_vecs_per_threadgroup;
uint64_t num_vecs;
uint64_t offset;
uint64_t seed;
};
// As above for float output; presumably value = random * scale + bias.
struct gptoss_f32_fill_random_args {
uint64_t num_vecs_per_threadgroup;
uint64_t num_vecs;
uint64_t offset;
uint64_t seed;
float scale;
float bias;
};
// Argument blocks passed to Metal kernels; field order is shared with the shaders.
// Sizes named `num_vecs*` count vectors rather than scalars (presumably 4-wide, as in
// the float4-based kernels). TODO confirm the vector width per kernel.

// Combining per-expert outputs into the token stream.
struct gptoss_accumulate_args {
uint32_t num_vecs_per_expert;
uint32_t num_vecs_per_threadgroup;
uint32_t num_vecs;
};
// Format conversion (presumably the mf4 -> f32 convert kernel).
struct gptoss_convert_args {
uint64_t num_vecs_per_threadgroup;
uint64_t num_vecs;
};
struct gptoss_embeddings_args {
uint32_t num_vecs;
};
// RMSNorm. `num_channels` is carried as a float, presumably so the kernel can divide by
// it directly; `epsilon` is presumably added to the mean square before the rsqrt.
struct gptoss_rmsnorm_args {
uint32_t num_vecs;
float num_channels;
float epsilon;
};
// Matrix-vector style matmul. `add` presumably selects accumulating into the output
// (cf. the ..._matmul_add launcher in metal-kernels.h).
struct gptoss_matmul_args {
uint32_t num_column_vecs;
uint32_t num_rows;
uint32_t add;
};
// Tiled matmul problem sizes (m x n output, k reduction).
struct gptoss_dense_matmul_args {
uint32_t m;
uint32_t n;
uint32_t k;
};
// Specialize qkv matmul args as it writes kv directly to the KV cache buffer.
struct gptoss_dense_matmul_qkv_args {
uint32_t m;
uint32_t n;
uint32_t k;
uint32_t max_tokens;
uint32_t token_offset;
};
// Presumably scatters token rows into per-expert order (counterpart of gptoss_gather_args).
struct gptoss_scatter_args {
uint32_t tokens;
uint32_t active_experts_per_token;
uint32_t token_stride;
};
// MoE argument blocks; field order is shared with the Metal shaders.

// Tiled per-expert matmul with fused SwiGLU. The *_stride_bytes fields give the byte
// distance between consecutive experts' weight blocks, weight scales and biases.
// swiglu_min/swiglu_max presumably clamp the SwiGLU inputs.
struct gptoss_moe_dense_matmul_swiglu_args {
uint32_t k;
uint32_t n;
uint32_t weight_blocks_expert_stride_bytes;
uint32_t weight_scales_expert_stride_bytes;
uint32_t bias_expert_stride_bytes;
float swiglu_min;
float swiglu_max;
};
// Same as above without the activation.
struct gptoss_moe_dense_matmul_args {
uint32_t k;
uint32_t n;
uint32_t weight_blocks_expert_stride_bytes;
uint32_t weight_scales_expert_stride_bytes;
uint32_t bias_expert_stride_bytes;
};
// Presumably builds per-expert offsets/counts from the routing decisions.
// The combine kernel reads `expert_offsets` and `intra_expert_offsets`.
struct gptoss_expert_routing_metadata_args {
uint32_t tokens;
uint32_t num_experts;
};
// Presumably gathers per-expert rows back into token order (counterpart of scatter).
struct gptoss_gather_args {
uint32_t tokens;
uint32_t active_experts_per_token;
uint32_t token_stride;
};
// Vocabulary projection; rows are split across threadgroups num_rows_per_threadgroup at a time.
struct gptoss_unembedding_args {
uint32_t num_column_vecs;
uint32_t num_rows_per_threadgroup;
uint32_t num_rows;
};
// Matrix-vector MoE matmul with fused SwiGLU; strides carry their units inline.
struct gptoss_moe_matmul_swiglu_args {
uint32_t num_column_vecs;
uint32_t num_rows;
uint32_t num_active_experts;
uint32_t weight_expert_stride; // in bytes
uint32_t output_expert_stride; // in elements
float swiglu_min;
float swiglu_max;
};
struct gptoss_moe_matmul_args {
uint32_t num_column_vecs;
uint32_t num_rows;
uint32_t num_active_experts;
uint32_t input_expert_stride; // in blocks of 32 elements
uint32_t weight_expert_stride; // in bytes
uint32_t output_expert_stride; // in elements
};
// Rotary position embedding arguments. The yarn_* fields are presumably the YaRN
// context-extension parameters; token_offset is presumably the position of the first
// token in the batch.
struct gptoss_rope_args {
uint32_t token_stride;
uint32_t token_offset;
uint32_t max_tokens;
float freq_scale;
float interpolation_scale;
float yarn_offset;
float yarn_scale;
float yarn_multiplier;
};
// Fused QKV projection + RoPE: the matmul sizes plus the same rotary parameters.
// max_tokens presumably sizes the KV cache that K/V are written into.
struct gptoss_qkv_args {
uint32_t num_column_vecs;
uint32_t num_rows;
uint32_t token_offset;
float freq_scale;
float interpolation_scale;
float yarn_offset;
float yarn_scale;
float yarn_multiplier;
uint32_t max_tokens;
};
// Softmax over logits; `temperature` presumably divides the logits first.
struct gptoss_softmax_args {
uint32_t num_vecs;
uint32_t num_vecs_per_threadgroup;
uint32_t max_threadgroups;
float temperature;
};
// Token sampling: RNG seed/offset and a blocked scan over num_dims entries.
struct gptoss_sample_args {
uint64_t rng_seed;
uint32_t rng_offset;
uint32_t num_blocks;
uint32_t num_dims;
uint32_t num_dims_per_block;
};
================================================
FILE: gpt_oss/metal/source/include/internal/log.h
================================================
#pragma once
#include
// Formats and emits one log message from a va_list; defined outside this header.
void gptoss_format_log(const char* format, va_list args);

// printf-style logging front end. It packs the variadic arguments into a va_list and
// hands them to gptoss_format_log. The format attribute lets the compiler type-check the
// arguments against `format`.
__attribute__((__format__(__printf__, 1, 2)))
inline static void gptoss_log(const char* format, ...) {
    va_list ap;
    va_start(ap, format);
    gptoss_format_log(format, ap);
    va_end(ap);
}

// Severity-prefixed wrappers that append a newline. `message` must be a string literal,
// because it is concatenated with the prefix at compile time.
#define GPTOSS_LOG_ERROR(message, ...) gptoss_log("Error: " message "\n", ##__VA_ARGS__)
#define GPTOSS_LOG_WARNING(message, ...) gptoss_log("Warning: " message "\n", ##__VA_ARGS__)
================================================
FILE: gpt_oss/metal/source/include/internal/macros.h
================================================
#pragma once
/***** Architecture detection macros *****/
// GPTOSS_ARCH_X86_64 and GPTOSS_ARCH_ARM64 are each 0 or 1, and exactly one must be 1.
// A build may predefine either one to override detection; predefined values are validated.
#ifdef GPTOSS_ARCH_X86_64
    #if GPTOSS_ARCH_X86_64 != 0 && GPTOSS_ARCH_X86_64 != 1
        #error "Invalid GPTOSS_ARCH_X86_64 value: must be either 0 or 1"
    #endif
#else
    // The parentheses spell out the precedence this test always had. _M_X64 counts as
    // x86-64 only when not compiling for ARM64EC, for which MSVC also defines _M_X64.
    #if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))
        #define GPTOSS_ARCH_X86_64 1
    #else
        #define GPTOSS_ARCH_X86_64 0
    #endif
#endif

#ifdef GPTOSS_ARCH_ARM64
    #if GPTOSS_ARCH_ARM64 != 0 && GPTOSS_ARCH_ARM64 != 1
        #error "Invalid GPTOSS_ARCH_ARM64 value: must be either 0 or 1"
    #endif
#else
    // ARM64EC builds are classified as ARM64.
    #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
        #define GPTOSS_ARCH_ARM64 1
    #else
        #define GPTOSS_ARCH_ARM64 0
    #endif
#endif

// Exactly one architecture must be selected.
#if GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 == 0
    #error "Unsupported architecture: neither x86-64 nor ARM64 detected"
#elif GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 != 1
    #error "Inconsistent architecture detection: both x86-64 and ARM64 detection macros are specified"
#endif
/***** Compiler portability macros *****/
// Branch-prediction hints; each evaluates to the truth value of `condition`. A build may
// predefine any of them. Otherwise GCC-compatible compilers get __builtin_expect, and
// other compilers a plain boolean with no hint.
#ifndef GPTOSS_LIKELY
#if defined(__GNUC__)
#define GPTOSS_LIKELY(condition) (__builtin_expect(!!(condition), 1))
#else
#define GPTOSS_LIKELY(condition) (!!(condition))
#endif
#endif
#ifndef GPTOSS_UNLIKELY
#if defined(__GNUC__)
#define GPTOSS_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
#else
#define GPTOSS_UNLIKELY(condition) (!!(condition))
#endif
#endif
// GPTOSS_UNPREDICTABLE marks a branch as hard to predict. It is resolved in two stages,
// so the order of the next two blocks matters.
// Stage 1: use __builtin_unpredictable when __has_builtin reports it.
#ifndef GPTOSS_UNPREDICTABLE
#if defined(__has_builtin)
#if __has_builtin(__builtin_unpredictable)
#define GPTOSS_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))
#endif
#endif
#endif
// Stage 2, if stage 1 did not define it: GCC >= 9 (excluding the Intel compiler)
// expresses the hint as a 50% expectation; otherwise there is no hint.
#ifndef GPTOSS_UNPREDICTABLE
#if defined(__GNUC__) && (__GNUC__ >= 9) && !defined(__INTEL_COMPILER)
#define GPTOSS_UNPREDICTABLE(condition) (__builtin_expect_with_probability(!!(condition), 0, 0.5))
#else
#define GPTOSS_UNPREDICTABLE(condition) (!!(condition))
#endif
#endif
// GPTOSS_DENSELY_PACKED_STRUCTURE: struct attribute that removes padding between members.
// A build-provided definition wins; otherwise only GCC-compatible compilers are supported.
#if defined(GPTOSS_DENSELY_PACKED_STRUCTURE)
    // Keep the caller-provided definition.
#elif defined(__GNUC__)
    #define GPTOSS_DENSELY_PACKED_STRUCTURE __attribute__((__packed__))
#else
    #error "Compiler-specific implementation of GPTOSS_DENSELY_PACKED_STRUCTURE required"
#endif

// GPTOSS_ALIGN(alignment): minimum alignment, in bytes, for a variable or member.
#if defined(GPTOSS_ALIGN)
    // Keep the caller-provided definition.
#elif defined(__GNUC__)
    #define GPTOSS_ALIGN(alignment) __attribute__((__aligned__(alignment)))
#elif defined(_MSC_VER)
    #define GPTOSS_ALIGN(alignment) __declspec(align(alignment))
#else
    #error "Compiler-specific implementation of GPTOSS_ALIGN required"
#endif

// GPTOSS_FORCE_INLINE: inline request that compilers supporting it treat as mandatory;
// elsewhere it degrades to a plain `inline`.
#if defined(GPTOSS_FORCE_INLINE)
    // Keep the caller-provided definition.
#elif defined(__GNUC__)
    #define GPTOSS_FORCE_INLINE inline __attribute__((__always_inline__))
#elif defined(_MSC_VER)
    #define GPTOSS_FORCE_INLINE __forceinline
#else
    #define GPTOSS_FORCE_INLINE inline
#endif

/***** Symbol visibility macros *****/
// GPTOSS_INTERNAL_SYMBOL: keeps a symbol out of the exported interface.
// ELF targets use "internal" visibility, Mach-O targets "hidden", and others expand to nothing.
#if defined(GPTOSS_INTERNAL_SYMBOL)
    // Keep the caller-provided definition.
#elif defined(__ELF__)
    #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("internal")))
#elif defined(__MACH__)
    #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("hidden")))
#else
    #define GPTOSS_INTERNAL_SYMBOL
#endif
================================================
FILE: gpt_oss/metal/source/include/internal/math.h
================================================
#pragma once
#include
#include
#include
// Returns ceil(numer / denom); denom must be non-zero.
// It is computed as quotient plus a remainder correction, so it cannot overflow. The
// (numer + denom - 1) / denom form wraps for numer near SIZE_MAX (SIZE_MAX / 2 gave 0).
inline static size_t math_ceil_div(size_t numer, size_t denom) {
    assert(denom != 0);
    return numer / denom + (numer % denom != 0);
}
// Returns the larger of two sizes (either one when they are equal).
inline static size_t math_max(size_t a, size_t b) {
    if (a < b) {
        return b;
    }
    return a;
}
// Returns the smaller of two sizes (either one when they are equal).
inline static size_t math_min(size_t a, size_t b) {
    return (b <= a) ? b : a;
}
// Saturating subtraction: a - b, clamped at zero instead of wrapping around.
inline static size_t math_sub_sat(size_t a, size_t b) {
    if (a <= b) {
        return 0;
    }
    return a - b;
}
// Rounds `number` down to a multiple of `multiple`, which must be a non-zero power of two.
// `-multiple` (modular negation of an unsigned value) is a mask that clears the low
// log2(multiple) bits.
// Declared `inline` like the other helpers in this header. A plain `static` function
// defined in a header triggers -Wunused-function in every translation unit that
// includes it without calling it.
inline static size_t math_round_down_po2(size_t number, size_t multiple) {
    assert(multiple != 0);
    assert((multiple & (multiple - 1)) == 0);
    return number & -multiple;
}
// Rounds `number` up to the next multiple of `multiple`, which must be a non-zero power
// of two; values that are already multiples are returned unchanged. The result wraps to 0
// when the next multiple exceeds SIZE_MAX.
// Declared `inline` like the other helpers in this header. A plain `static` function
// defined in a header triggers -Wunused-function in every translation unit that
// includes it without calling it.
inline static size_t math_round_up_po2(size_t number, size_t multiple) {
    assert(multiple != 0);
    assert((multiple & (multiple - 1)) == 0);
    const size_t multiple_mask = multiple - 1;
    if ((number & multiple_mask) != 0) {
        // Set every low bit, then let the +1 carry into the next multiple.
        number |= multiple_mask;
        number += 1;
    }
    return number;
}
================================================
FILE: gpt_oss/metal/source/include/internal/metal-kernels.h
================================================
#pragma once
#include
#include
#include
#ifdef __cplusplus
extern "C" {
#endif
#include
#include
#include
#include
#include
// Host-side encoders for individual Metal compute kernels; their definitions are not in
// this header. Each returns a gptoss_status.
// Conventions, judging by the parameter names:
// - `command_buffer` is the buffer the launch is encoded into;
// - the `*_fn` argument is the compiled kernel;
// - each (`*_buffer`, `*_offset`) pair names a Metal buffer and an offset into it
//   (presumably in bytes; TODO confirm against the implementation).

// Presumably fills `num_elements` uint32 values of output_buffer, starting at
// output_offset, with pseudo-random values determined by rng_seed/rng_offset.
// threadgroup_size and max_threadgroups bound the dispatch.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* u32_fill_random_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint64_t num_elements,
uint64_t rng_seed,
uint64_t rng_offset);
// As above with fp32 output; rng_min/rng_max presumably give the output value range.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_fill_random_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint64_t num_elements,
uint64_t rng_seed,
uint64_t rng_offset,
float rng_min,
float rng_max);
// As above with bfloat16 output.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* bf16_fill_random_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint64_t num_elements,
uint64_t rng_seed,
uint64_t rng_offset,
float rng_min,
float rng_max);
// Presumably decodes 4-bit ("mf4") data into `num_elements` fp32 values in output_buffer,
// with blocks in block_buffer and per-block scales in scale_buffer.
// Unlike most launchers here, it takes no buffer offsets.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* mf4_f32_convert_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* block_buffer,
const struct gptoss_metal_buffer* scale_buffer,
const struct gptoss_metal_buffer* output_buffer,
uint64_t num_elements);
// Presumably looks up the bf16 embedding rows (weight_buffer) for `num_tokens` token IDs
// read from token_buffer and writes them as fp32, `num_channels` values per token.
// control_buffer presumably holds a struct gptoss_control.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* bf16_f32_embeddings_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* token_buffer,
size_t token_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels);
// Presumably RMS-normalizes `num_tokens` fp32 rows of `num_channels` values, scaling by
// bf16 weights, with `epsilon` for numerical stability.
// Unlike most launchers here, it takes no threadgroup_size.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels,
float epsilon);
// fp32 activations times bf16 weights, plus bias. Presumably maps `num_tokens` rows of
// `num_cols` inputs to `num_rows` outputs each.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
// Fused QKV projection. It also takes kv_buffer/kv_offset, head geometry and the
// RoPE/YaRN parameters. Presumably it writes K/V for positions starting at token_offset
// into a cache sized for max_tokens.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* kv_buffer,
size_t kv_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_q_heads,
uint32_t num_kv_heads,
uint32_t attn_head_dim,
uint32_t token_offset,
uint32_t max_tokens,
float rope_base,
float interpolation_scale,
float yarn_offset,
float yarn_scale,
float yarn_multiplier);
// Same parameters as ..._f32_bf16w_matmul. Presumably it accumulates into output_buffer
// instead of overwriting it (cf. the `add` field of struct gptoss_matmul_args).
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
// Tiled ("dense") variants of the bf16-weight matmul for the QKV, attention-output and
// MLP-gate projections. They take no threadgroup_size; presumably the launch shape comes
// from the QKV_*, ATTN_OUTPUT_* and MLP_GATE_* tile constants in kernel-args.h
// (TODO confirm).
// The QKV variant also writes to kv_buffer and takes max_tokens/token_offset,
// mirroring struct gptoss_dense_matmul_qkv_args.
enum gptoss_status
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* kv_buffer,
size_t kv_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows,
uint32_t max_tokens,
uint32_t token_offset);
enum gptoss_status
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
enum gptoss_status
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
// Presumably projects final hidden states onto the vocabulary (logits in output_buffer)
// and also writes per-token argmax results to argmax_buffer.
// The kernel parameter is named f32_bf16w_matmul_fn even though this is the unembedding
// launcher.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_buffer,
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
// Encodes a launch of the `f32_mf4w_moe_matmul_swiglu_fn` kernel function.
// Weights are passed as separate block and scale buffers (the "mf4w" naming
// suggests a block-scaled 4-bit weight format -- TODO confirm). `expert_buffer`
// supplies the expert selection and `expert_stride` spaces per-expert weights
// (units not visible here -- confirm). `swiglu_limit` is forwarded to the
// kernel. Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* expert_buffer,
size_t expert_offset,
const struct gptoss_metal_buffer* weight_block_buffer,
size_t weight_block_offset,
const struct gptoss_metal_buffer* weight_scale_buffer,
size_t weight_scale_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
float swiglu_limit,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
uint32_t num_cols,
uint32_t num_rows);
// Encodes a launch of the `f32_mf4w_moe_matmul_fn` kernel function. It has the
// same buffer layout as the `_swiglu` variant (block + scale weight buffers,
// expert selection buffer, bias, output, control) but takes no `swiglu_limit`.
// Offsets are presumably in bytes; `expert_stride` units are not visible here.
// Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* expert_buffer,
size_t expert_offset,
const struct gptoss_metal_buffer* weight_block_buffer,
size_t weight_block_offset,
const struct gptoss_metal_buffer* weight_scale_buffer,
size_t weight_scale_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
uint32_t num_cols,
uint32_t num_rows);
// Encodes a launch of the `f32_rope_fn` kernel function over
// `activations_buffer`, also binding `kv_buffer`.
// - Rotary-embedding parameters forwarded to the kernel: rope_base,
//   interpolation_scale, yarn_offset, yarn_scale, yarn_multiplier.
// - Head geometry forwarded to the kernel: num_q_heads, num_kv_heads,
//   attn_head_dim.
// - max_tokens and token_offset presumably locate the new tokens inside the KV
//   buffer -- TODO confirm.
// Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_rope_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* activations_buffer,
size_t activations_offset,
const struct gptoss_metal_buffer* kv_buffer,
size_t kv_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
float rope_base,
float interpolation_scale,
float yarn_offset,
float yarn_scale,
float yarn_multiplier,
uint32_t num_tokens,
uint32_t num_q_heads,
uint32_t num_kv_heads,
uint32_t attn_head_dim,
uint32_t max_tokens,
uint32_t token_offset);
// Encodes a launch of the `f32_accumulate_fn` kernel function, reading
// `input_buffer` and `expert_buffer` and writing `output_buffer`.
// `expert_buffer` presumably holds per-token expert ids/weights -- TODO
// confirm. The dispatch is bounded by threadgroup_size / max_threadgroups and
// sized by num_channels, num_tokens and num_experts. Returns a gptoss_status
// code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_accumulate_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* expert_buffer,
size_t expert_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_experts);
// Encodes a launch of the `expert_routing_metadata_fn` kernel function. It
// binds `expert_predictions_buffer` plus `expert_offsets_buffer` and
// `intra_expert_offsets_buffer`. The offsets buffers are presumably outputs
// consumed by the scatter / gather launchers, which take the same three
// buffers -- TODO confirm. No control buffer is passed. Returns a
// gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* expert_routing_metadata_fn,
const struct gptoss_metal_buffer* expert_predictions_buffer,
size_t expert_predictions_offset,
const struct gptoss_metal_buffer* expert_offsets_buffer,
size_t expert_offsets_offset,
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
size_t intra_expert_offsets_offset,
uint32_t num_tokens,
uint32_t num_experts);
// Encodes a launch of the `f32_scatter_fn` kernel function. It reads
// `input_buffer` and writes `output_buffer`, using the expert predictions and
// the (intra-)expert offsets buffers to place data -- presumably grouping
// tokens by expert; TODO confirm in the kernel. Sized by num_channels,
// num_tokens, num_active_experts. Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_scatter_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* expert_predictions_buffer,
size_t expert_predictions_offset,
const struct gptoss_metal_buffer* expert_offsets_buffer,
size_t expert_offsets_offset,
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
size_t intra_expert_offsets_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_active_experts);
// Encodes a launch of the `f32_mf4w_moe_dense_matmul_swiglu_fn` kernel
// function. Unlike the per-token MoE launcher, it takes `expert_offsets_buffer`
// (presumably produced by the routing-metadata kernel -- TODO confirm) and no
// control buffer. The expert weight stride is given explicitly in bytes
// (`expert_stride_bytes`). `swiglu_limit` is forwarded to the kernel. Returns a
// gptoss_status code.
enum gptoss_status
gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
const struct gptoss_metal_buffer* expert_offsets_buffer,
size_t expert_offsets_offset,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_block_buffer,
size_t weight_block_offset,
const struct gptoss_metal_buffer* weight_scale_buffer,
size_t weight_scale_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
float swiglu_limit,
uint32_t expert_stride_bytes,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_cols,
uint32_t num_rows);
// Encodes a launch of the `f32_mf4w_moe_dense_matmul_fn` kernel function. It
// has the same buffer layout as the `_swiglu` dense variant (expert offsets,
// input, block + scale weights, bias, output) but takes no `swiglu_limit`. The
// expert weight stride is in bytes (`expert_stride_bytes`). Returns a
// gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
const struct gptoss_metal_buffer* expert_offsets_buffer,
size_t expert_offsets_offset,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* weight_block_buffer,
size_t weight_block_offset,
const struct gptoss_metal_buffer* weight_scale_buffer,
size_t weight_scale_offset,
const struct gptoss_metal_buffer* bias_buffer,
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint32_t expert_stride_bytes,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_cols,
uint32_t num_rows);
// Encodes a launch of the `f32_gather_and_accumulate_e4_fn` kernel function.
// It takes the same buffer set as the scatter launcher (input, expert
// predictions, expert offsets, intra-expert offsets, output) and presumably
// performs the inverse gather plus a sum into `output_buffer`. The `_e4` suffix
// presumably means it is specialised for 4 active experts. Both points are
// TODO confirm. Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* expert_predictions_buffer,
size_t expert_predictions_offset,
const struct gptoss_metal_buffer* expert_offsets_buffer,
size_t expert_offsets_offset,
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
size_t intra_expert_offsets_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_active_experts);
// Encodes a launch of the `f32_topk_fn` kernel function, reading
// `input_buffer` and writing `output_buffer`. Judging by the parameter names,
// it selects `num_active_experts` of `num_experts` per token -- TODO confirm
// the output format in the kernel. Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_topk_fn,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_active_experts);
// Encodes a launch of the `f32_sdpa_fn` kernel function. It binds the query
// buffer `q_buffer`, the cache buffer `kv_buffer`, an extra `s_buffer` (role
// not visible here -- TODO confirm) and `output_buffer`.
// - window: presumably an attention window length (0 may mean unlimited) --
//   confirm.
// - kv_stride: spacing of entries in `kv_buffer`; units not visible here.
// - Geometry: num_q_tokens, num_kv_tokens, num_q_heads, num_kv_heads, head_dim.
// Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_sdpa_fn,
const struct gptoss_metal_buffer* q_buffer,
size_t q_offset,
const struct gptoss_metal_buffer* kv_buffer,
size_t kv_offset,
const struct gptoss_metal_buffer* s_buffer,
size_t s_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t window,
uint32_t kv_stride,
uint32_t num_q_tokens,
uint32_t num_kv_tokens,
uint32_t num_q_heads,
uint32_t num_kv_heads,
uint32_t head_dim);
// Encodes a launch of the `f32_softmax_fn` kernel function.
// - Inputs: `score_buffer` and `argmax_buffer`.
// - Outputs: presumably `prob_buffer` and per-threadgroup partial sums in
//   `sum_buffer` -- TODO confirm.
// - `temperature` is forwarded to the kernel.
// - The dispatch is bounded by threadgroup_size / max_threadgroups.
// - The chosen launch geometry is presumably reported through the two out
//   pointers: *num_threadgroups_out and *num_channels_per_threadgroup_out.
//   These are likely the values the sample launcher takes as num_blocks /
//   num_channels_per_block -- confirm.
// Returns a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_softmax_fn,
size_t threadgroup_size,
size_t max_threadgroups,
const struct gptoss_metal_buffer* score_buffer,
size_t score_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
const struct gptoss_metal_buffer* prob_buffer,
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
float temperature,
uint32_t* num_threadgroups_out,
uint32_t* num_channels_per_threadgroup_out);
// Encodes a launch of the `f32_sample_fn` kernel function. It reads
// `prob_buffer` and `sum_buffer` and writes `token_buffer` (presumably the
// sampled token id -- TODO confirm). Randomness comes from the `rng_seed` /
// `rng_offset` pair. The problem is described as num_blocks blocks of
// num_channels_per_block channels, num_channels in total. `min_threadgroup_size`
// is a lower bound on the threadgroup size the implementation may pick. Returns
// a gptoss_status code.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_sample_fn,
size_t min_threadgroup_size,
const struct gptoss_metal_buffer* prob_buffer,
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
const struct gptoss_metal_buffer* token_buffer,
size_t token_offset,
const struct gptoss_metal_buffer* control_buffer,
size_t control_offset,
uint64_t rng_seed,
uint32_t rng_offset,
uint32_t num_blocks,
uint32_t num_channels,
uint32_t num_channels_per_block);
#ifdef __cplusplus
} // extern "C"
#endif
================================================
FILE: gpt_oss/metal/source/include/internal/metal.h
================================================
#pragma once
#include
#include
#ifdef __cplusplus
extern "C" {
#endif
// C handle for a Metal device together with size/limit fields that the
// create function fills in (presumably queried from the device -- TODO confirm
// in the implementation).
struct gptoss_metal_device {
void* object; // Objective-C `id` stored as an opaque pointer (presumably id<MTLDevice> -- confirm)
size_t num_cores;
size_t max_buffer_size;
size_t max_threadgroup_memory;
size_t max_threadgroup_threads_x;
size_t max_threadgroup_threads_y;
size_t max_threadgroup_threads_z;
};
// Initialises *device_out for the system default device; returns a gptoss_status.
enum gptoss_status gptoss_metal_device_create_system_default(
struct gptoss_metal_device* device_out);
// Releases *device. The C++ wrapper also calls this on zero-filled (moved-from)
// structs, so it is presumably a no-op for a NULL object -- confirm.
enum gptoss_status gptoss_metal_device_release(
struct gptoss_metal_device* device);
// C handle for a Metal shader library.
struct gptoss_metal_library {
void* object; // Objective-C `id` stored as an opaque pointer (presumably id<MTLLibrary> -- confirm)
};
// Loads the default library for `device` into *library_out; returns a gptoss_status.
enum gptoss_status gptoss_metal_library_create_default(
const struct gptoss_metal_device* device,
struct gptoss_metal_library* library_out);
// Releases *library (also called on zero-filled moved-from structs by the C++ wrapper).
enum gptoss_status gptoss_metal_library_release(
struct gptoss_metal_library* library);
// C handle for a kernel function looked up by name, plus its compute pipeline
// state and the limits recorded for it at creation time.
struct gptoss_metal_function {
void* function_object; // Objective-C `id` as an opaque pointer (presumably id<MTLFunction> -- confirm)
void* pipeline_state_object; // Objective-C `id` as an opaque pointer (presumably id<MTLComputePipelineState> -- confirm)
size_t max_threadgroup_threads;
size_t simdgroup_threads;
size_t static_threadgroup_memory;
};
// Looks up kernel `name` in `library` and fills *function_out; returns a gptoss_status.
enum gptoss_status gptoss_metal_function_create(
const struct gptoss_metal_library* library,
const char* name,
struct gptoss_metal_function* function_out);
// Releases *function (also called on zero-filled moved-from structs by the C++ wrapper).
enum gptoss_status gptoss_metal_function_release(
struct gptoss_metal_function* function);
// C handle for a Metal buffer: the object, its size, and a CPU-visible pointer
// (`ptr`, presumably the buffer contents for shared storage -- confirm).
struct gptoss_metal_buffer {
void* object; // Objective-C `id` stored as an opaque pointer (presumably id<MTLBuffer> -- confirm)
size_t size;
void* ptr;
};
// Allocates a `size`-byte buffer on `device`; `data` may be NULL (the C++ wrapper
// defaults it to nullptr), otherwise presumably copied in -- confirm.
enum gptoss_status gptoss_metal_buffer_create(
const struct gptoss_metal_device* device,
size_t size,
const void* data,
struct gptoss_metal_buffer* buffer_out);
// Creates a buffer over existing memory `data` of `size` bytes (presumably without
// copying -- confirm alignment/lifetime requirements in the implementation).
enum gptoss_status gptoss_metal_buffer_wrap(
const struct gptoss_metal_device* device,
size_t size,
const void* data,
struct gptoss_metal_buffer* buffer_out);
// Releases *buffer (also called on zero-filled moved-from structs by the C++ wrapper).
enum gptoss_status gptoss_metal_buffer_release(
struct gptoss_metal_buffer* buffer);
// C handle for a Metal command queue.
struct gptoss_metal_command_queue {
void* object; // Objective-C `id` stored as an opaque pointer (presumably id<MTLCommandQueue> -- confirm)
};
// Creates a command queue on `device` into *command_queue_out; returns a gptoss_status.
enum gptoss_status gptoss_metal_command_queue_create(
const struct gptoss_metal_device* device,
struct gptoss_metal_command_queue* command_queue_out);
// Releases *command_queue (also called on zero-filled moved-from structs by the C++ wrapper).
enum gptoss_status gptoss_metal_command_queue_release(
struct gptoss_metal_command_queue* command_queue);
// C handle for a Metal command buffer, plus the encode / commit / wait API.
struct gptoss_metal_command_buffer {
void* object; // Objective-C `id` stored as an opaque pointer (presumably id<MTLCommandBuffer> -- confirm)
};
// Creates a command buffer from `command_queue` into *command_buffer_out.
enum gptoss_status gptoss_metal_command_buffer_create(
const struct gptoss_metal_command_queue* command_queue,
struct gptoss_metal_command_buffer* command_buffer_out);
// Encodes a fill of `size` units at `offset` in `buffer` with the byte value
// `fill_value` (offset/size presumably in bytes -- confirm).
enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_buffer* buffer,
size_t offset,
size_t size,
uint8_t fill_value);
// Encodes a copy of `size` units from input_buffer+input_offset to
// output_buffer+output_offset (presumably bytes -- confirm).
enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_buffer* input_buffer,
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
size_t size);
// Generic kernel dispatch: threadgroup size and threadgroup count in x/y/z, an
// opaque `params` blob of `params_size` bytes, `num_device_buffers` buffers
// bound in array order, and a threadgroup-memory size.
// `device_buffer_offsets` may be NULL -- the C++ wrapper always passes nullptr,
// which presumably means every offset is zero (confirm in the implementation).
enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* function,
size_t threadgroup_size_x,
size_t threadgroup_size_y,
size_t threadgroup_size_z,
size_t num_threadgroups_x,
size_t num_threadgroups_y,
size_t num_threadgroups_z,
size_t params_size,
const void* params,
size_t num_device_buffers,
const struct gptoss_metal_buffer** device_buffers,
const size_t* device_buffer_offsets,
size_t threadgroup_buffer_size);
// Submits the encoded work for execution.
enum gptoss_status gptoss_metal_command_buffer_commit(
const struct gptoss_metal_command_buffer* command_buffer);
// Blocks until the committed work finishes and stores the elapsed time in
// *elapsed_seconds (the C++ wrapper reads it back as the return value).
enum gptoss_status gptoss_metal_command_buffer_wait_completion(
const struct gptoss_metal_command_buffer* command_buffer,
double* elapsed_seconds);
// Releases *command_buffer (also called on zero-filled moved-from structs by the C++ wrapper).
enum gptoss_status gptoss_metal_command_buffer_release(
struct gptoss_metal_command_buffer* command_buffer);
#ifdef __cplusplus
} // extern "C"
#endif
================================================
FILE: gpt_oss/metal/source/include/internal/metal.hpp
================================================
#pragma once
#include
#include
#include
#include
#include
#include
#include
#include
namespace gptoss {
// Converts a C-API status into an exception: returns normally when `s` is
// gptoss_status_success, otherwise throws std::runtime_error whose message is
// `what` (typically the name of the failing call).
inline void Check(gptoss_status s, const char* what) {
  if (s == gptoss_status_success) {
    return;
  }
  throw std::runtime_error(what);
}
// Rounds `p` up to the nearest multiple of `q`; returns `p` unchanged when it
// is already a multiple (including p == 0).
// Throws std::invalid_argument when q == 0. Previously that case reached
// `p % 0`, which is undefined behavior.
// The arithmetic is unsigned, so if the rounded value does not fit in
// std::size_t the result wraps modulo 2^N; callers must keep p + q in range.
inline std::size_t round_up(std::size_t p, std::size_t q) {
  if (q == 0) {
    throw std::invalid_argument("round_up: divisor must be non-zero");
  }
  const std::size_t r = p % q;
  return r == 0 ? p : p - r + q;
}
namespace metal {
// Move-only RAII owner of a gptoss_metal_device for the system default GPU.
// Moved-from objects hold a zero-filled handle, which the destructor still
// passes to gptoss_metal_device_release.
class Device {
public:
  // Throws std::runtime_error("create Device") if device creation fails.
  Device() {
    Check(gptoss_metal_device_create_system_default(&device_), "create Device");
  }

  ~Device() { gptoss_metal_device_release(&device_); }

  Device(const Device&) = delete;
  Device& operator=(const Device&) = delete;

  Device(Device&& other) noexcept { adopt(other); }

  // Releases the current handle, then takes ownership of `other`'s.
  Device& operator=(Device&& other) noexcept {
    if (&other == this) {
      return *this;
    }
    gptoss_metal_device_release(&device_);
    adopt(other);
    return *this;
  }

  const gptoss_metal_device* handle() const noexcept { return &device_; }

  // Limits recorded in the C struct when the device was created.
  size_t max_buffer_size() const noexcept { return device_.max_buffer_size; }
  size_t max_threadgroup_memory() const noexcept { return device_.max_threadgroup_memory; }
  size_t max_threadgroup_threads_x() const noexcept { return device_.max_threadgroup_threads_x; }
  size_t max_threadgroup_threads_y() const noexcept { return device_.max_threadgroup_threads_y; }
  size_t max_threadgroup_threads_z() const noexcept { return device_.max_threadgroup_threads_z; }

private:
  // Copies `other`'s handle into this object and zero-fills `other`'s copy so
  // that only one owner remains.
  void adopt(Device& other) noexcept {
    device_ = other.device_;
    std::memset(&other.device_, 0, sizeof(other.device_));
  }

  gptoss_metal_device device_{};
};
// Move-only RAII owner of the default gptoss_metal_library for a Device.
// Moved-from objects hold a zero-filled handle.
class Library {
public:
  // Throws std::runtime_error("gptoss_metal_library_create_default") on failure.
  explicit Library(const Device& dev) {
    Check(gptoss_metal_library_create_default(dev.handle(), &library_),
          "gptoss_metal_library_create_default");
  }

  ~Library() { gptoss_metal_library_release(&library_); }

  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;

  Library(Library&& other) noexcept { adopt(other); }

  // Releases the current handle, then takes ownership of `other`'s.
  Library& operator=(Library&& other) noexcept {
    if (&other == this) {
      return *this;
    }
    gptoss_metal_library_release(&library_);
    adopt(other);
    return *this;
  }

  const gptoss_metal_library* handle() const noexcept { return &library_; }

private:
  // Copies `other`'s handle into this object and zero-fills `other`'s copy.
  void adopt(Library& other) noexcept {
    library_ = other.library_;
    std::memset(&other.library_, 0, sizeof(other.library_));
  }

  gptoss_metal_library library_{};
};
// Move-only RAII owner of a named kernel function from a Library.
// Moved-from objects hold a zero-filled handle.
class Function {
public:
  // Looks up `name`; throws std::runtime_error("gptoss_metal_function_create")
  // on failure.
  Function(const Library& library, const char* name) {
    Check(gptoss_metal_function_create(library.handle(), name, &function_),
          "gptoss_metal_function_create");
  }

  ~Function() { gptoss_metal_function_release(&function_); }

  Function(const Function&) = delete;
  Function& operator=(const Function&) = delete;

  Function(Function&& other) noexcept { adopt(other); }

  // Releases the current handle, then takes ownership of `other`'s.
  Function& operator=(Function&& other) noexcept {
    if (&other == this) {
      return *this;
    }
    gptoss_metal_function_release(&function_);
    adopt(other);
    return *this;
  }

  const gptoss_metal_function* handle() const noexcept { return &function_; }

  // Per-function limits recorded in the C struct at creation time.
  size_t max_threadgroup_threads() const noexcept { return function_.max_threadgroup_threads; }
  size_t simdgroup_threads() const noexcept { return function_.simdgroup_threads; }
  size_t static_threadgroup_memory() const noexcept { return function_.static_threadgroup_memory; }

private:
  // Copies `other`'s handle into this object and zero-fills `other`'s copy.
  void adopt(Function& other) noexcept {
    function_ = other.function_;
    std::memset(&other.function_, 0, sizeof(other.function_));
  }

  gptoss_metal_function function_{};
};
// Move-only RAII owner of a gptoss_metal_buffer allocated on a Device.
// Moved-from objects hold a zero-filled handle (size 0, null ptr).
class Buffer {
public:
  // Allocates `size` bytes; `data` (optional) is forwarded to the C API.
  // Throws std::runtime_error("create buffer") on failure.
  Buffer(const Device& dev, size_t size, const void* data = nullptr) {
    Check(gptoss_metal_buffer_create(dev.handle(), size, data, &buffer_), "create buffer");
  }

  ~Buffer() { gptoss_metal_buffer_release(&buffer_); }

  Buffer(const Buffer&) = delete;
  Buffer& operator=(const Buffer&) = delete;

  Buffer(Buffer&& other) noexcept { adopt(other); }

  // Releases the current buffer, then takes ownership of `other`'s.
  Buffer& operator=(Buffer&& other) noexcept {
    if (&other == this) {
      return *this;
    }
    gptoss_metal_buffer_release(&buffer_);
    adopt(other);
    return *this;
  }

  size_t size() const noexcept { return buffer_.size; }
  void* ptr() const noexcept { return buffer_.ptr; }
  const gptoss_metal_buffer* handle() const noexcept { return &buffer_; }

private:
  // Copies `other`'s handle into this object and zero-fills `other`'s copy.
  void adopt(Buffer& other) noexcept {
    buffer_ = other.buffer_;
    std::memset(&other.buffer_, 0, sizeof(other.buffer_));
  }

  gptoss_metal_buffer buffer_{};
};
// Move-only RAII owner of a gptoss_metal_command_queue on a Device.
// Moved-from objects hold a zero-filled handle.
class CommandQueue {
public:
  // Throws std::runtime_error("gptoss_metal_command_queue_create") on failure.
  explicit CommandQueue(const Device& dev) {
    Check(gptoss_metal_command_queue_create(dev.handle(), &command_queue_),
          "gptoss_metal_command_queue_create");
  }

  ~CommandQueue() { gptoss_metal_command_queue_release(&command_queue_); }

  CommandQueue(const CommandQueue&) = delete;
  CommandQueue& operator=(const CommandQueue&) = delete;

  CommandQueue(CommandQueue&& other) noexcept { adopt(other); }

  // Releases the current queue, then takes ownership of `other`'s.
  CommandQueue& operator=(CommandQueue&& other) noexcept {
    if (&other == this) {
      return *this;
    }
    gptoss_metal_command_queue_release(&command_queue_);
    adopt(other);
    return *this;
  }

  const gptoss_metal_command_queue* handle() const noexcept { return &command_queue_; }

private:
  // Copies `other`'s handle into this object and zero-fills `other`'s copy.
  void adopt(CommandQueue& other) noexcept {
    command_queue_ = other.command_queue_;
    std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
  }

  gptoss_metal_command_queue command_queue_{};
};
class CommandBuffer {
public:
inline explicit CommandBuffer(const CommandQueue& command_queue) {
Check(gptoss_metal_command_buffer_create(command_queue.handle(), &command_buffer_),
"gptoss_metal_command_buffer_create");
}
inline ~CommandBuffer() {
gptoss_metal_command_buffer_release(&command_buffer_);
}
CommandBuffer(const CommandBuffer&) = delete;
CommandBuffer& operator=(const CommandBuffer&) = delete;
inline CommandBuffer(CommandBuffer&& other) noexcept {
command_buffer_ = other.command_buffer_;
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
}
inline CommandBuffer& operator=(CommandBuffer&& other) noexcept {
if (this != &other) {
gptoss_metal_command_buffer_release(&command_buffer_);
command_buffer_ = other.command_buffer_;
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
}
return *this;
}
inline void encode_launch_kernel(const Function& function,
const std::array& threadgroup_size,
const std::array& num_threadgroups,
size_t params_size, const void* params,
std::initializer_list device_buffers = {},
size_t threadgroup_buffer_size = 0)
{
std::vector buffer_handles(device_buffers.size());
std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),
[](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
Check(gptoss_metal_command_buffer_encode_launch_kernel(
&command_buffer_, function.handle(),
threadgroup_size[0], threadgroup_size[1], threadgroup_size[2],
num_threadgroups[0], num_threadgroups[1], num_threadgroups[2],
params_size, params,
buffer_handles.size(),
buffer_handles.data(),
/*buffer_offsets=*/nullptr,
threadgroup_buffer_size),
"gptoss_metal_command_buffer_encode_launch_kernel");
}
inline void encode_launch_f32_fill_random(const Function& f32_fill_random_fn,
size_t threadgroup_size,
size_t num_threadgroups,
const Buffer& output_buffer,
size_t output_offset,
size_t num_channels,
uint64_t rng_seed,
uint64_t rng_offset,
float rng_min,
float rng_max)
{
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
&command_buffer_, f32_fill_random_fn.handle(),
threadgroup_size, num_threadgroups,
output_buffer.handle(), output_offset,
num_channels,
rng_seed, rng_offset, rng_min, rng_max),
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
}
inline void encode_launch_bf16_fill_random(const Function& bf16_fill_random_fn,
size_t threadgroup_size,
size_t num_threadgroups,
const Buffer& output_buffer,
size_t output_offset,
size_t num_channels,
uint64_t rng_seed,
uint64_t rng_offset,
float rng_min,
float rng_max)
{
Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
&command_buffer_, bf16_fill_random_fn.handle(),
threadgroup_size, num_threadgroups,
output_buffer.handle(), output_offset,
num_channels,
rng_seed, rng_offset, rng_min, rng_max),
"gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
}
inline void encode_launch_u32_fill_random(const Function& u32_fill_random_fn,
size_t threadgroup_size,
size_t num_threadgroups,
const Buffer& output_buffer,
size_t output_offset,
size_t num_channels,
uint64_t rng_seed,
uint64_t rng_offset)
{
Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
&command_buffer_, u32_fill_random_fn.handle(),
threadgroup_size, num_threadgroups,
output_buffer.handle(), output_offset,
num_channels,
rng_seed, rng_offset),
"gptoss_metal_command_buffer_encode_launch_u32_fill_random");
}
inline void commit() {
Check(gptoss_metal_command_buffer_commit(&command_buffer_), "commit");
}
inline double wait_completion() {
double secs = 0.0;
Check(gptoss_metal_command_buffer_wait_completion(&command_buffer_, &secs), "wait completion");
return secs;
}
inline const gptoss_metal_command_buffer* handle() const noexcept { return &command_buffer_; }
private:
gptoss_metal_command_buffer command_buffer_{};
};
} // namespace metal
} // namespace gptoss
================================================
FILE: gpt_oss/metal/source/include/internal/model.h
================================================
#pragma once
#ifndef __cplusplus
#include
#endif
#include