[
  {
    "path": ".github/CODEOWNERS",
    "content": "@openai/developer-experience\ndkundel-openai\nMaratyszcza\nscott-oai\nvolsgd\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links:\n  - name: 🐛 Model Issues\n    url: https://huggingface.co/openai/gpt-oss-120b/discussions\n    about: For general questions about the models, please use the Community feature on Hugging Face.\n  - name: 💡 General Feedback\n    url: https://openai.com/open-models\n    about: Suggest new features on our feature request page.\n"
  },
  {
    "path": ".github/workflows/CI.yml",
    "content": "name: CI\n\non:\n  release:\n    types: [published]\n  push:\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n\n# Minimal repo-level permissions; job-level permissions override where needed.\npermissions:\n  contents: read\n  id-token: write\n\njobs:\n  publish:\n    name: Build & Publish to PyPI (Trusted Publishing)\n    runs-on: ubuntu-latest\n\n    # Run in the GitHub environment named \"release\" so you can gate it with approvals.\n    environment: release\n\n    # Extra permissions required for pypa action to do OIDC exchange:\n    permissions:\n      contents: read\n      id-token: write\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.12\"\n\n      - name: Install build tools\n        run: |\n          python -m pip install --upgrade pip setuptools wheel build\n\n      - name: Install uv (if needed)\n        run: |\n          python -m pip install --upgrade uv || true\n\n      - name: Build package with uv\n        run: |\n          pwd\n          ls -la\n          uv build\n\n      - name: Inspect dist folder\n        run: |\n          ls -la dist || ls -la build || echo \"no dist/ or build/ — check uv output\"\n\n      - name: Publish to PyPI using Trusted Publishing\n        # Note: No pypi_token / username / password provided — Trusted Publishing via OIDC is used.\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          attestations: true # optional (default for Trusted Publishing) - set to false to disable\n"
  },
  {
    "path": ".gitignore",
    "content": "build\n_skbuild\ntmp*\n__pycache__\n*.egg*\nnode_modules/\n*.log"
  },
  {
    "path": "CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.26)\nproject(gpt_oss LANGUAGES C CXX)\n\n# If not defined externally, auto-detect\nif(NOT DEFINED GPTOSS_BUILD_METAL)\n  if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES \"arm64\")\n    message(STATUS \"Apple Silicon detected → enabling GPTOSS_BUILD_METAL\")\n    set(GPTOSS_BUILD_METAL ON)\n  else()\n    message(STATUS \"Non-Apple Silicon → disabling GPTOSS_BUILD_METAL\")\n    set(GPTOSS_BUILD_METAL OFF)\n  endif()\nelse()\n  message(STATUS \"GPTOSS_BUILD_METAL manually set to: ${GPTOSS_BUILD_METAL}\")\nendif()\n\n# Now declare it as a cache variable (respects user-provided value)\nset(GPTOSS_BUILD_METAL \"${GPTOSS_BUILD_METAL}\" CACHE BOOL \"Enable Metal backend\")\n\nif(GPTOSS_BUILD_METAL)\n  enable_language(OBJC)\n  add_subdirectory(gpt_oss/metal)\nendif()\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include _build * "
  },
  {
    "path": "README.md",
    "content": "<img alt=\"gpt-oss-120\" src=\"./docs/gpt-oss.svg\">\n<p align=\"center\">\n  <a href=\"https://gpt-oss.com\"><strong>Try gpt-oss</strong></a> ·\n  <a href=\"https://cookbook.openai.com/topic/gpt-oss\"><strong>Guides</strong></a> ·\n  <a href=\"https://arxiv.org/abs/2508.10925\"><strong>Model card</strong></a> ·\n  <a href=\"https://openai.com/index/introducing-gpt-oss/\"><strong>OpenAI blog</strong></a>\n</p>\n<p align=\"center\">\n  <strong>Download <a href=\"https://huggingface.co/openai/gpt-oss-120b\">gpt-oss-120b</a> and <a href=\"https://huggingface.co/openai/gpt-oss-20b\">gpt-oss-20b</a> on Hugging Face</strong>\n</p>\n\n<br>\n\nWelcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.\n\nWe're releasing two flavors of these open models:\n\n- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single 80GB GPU (like NVIDIA H100 or AMD MI300X) (117B parameters with 5.1B active parameters)\n- `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters)\n\nBoth models were trained using our [harmony response format][harmony] and should only be used with this format; otherwise, they will not work correctly.\n\n## Table of Contents\n- [Highlights](#highlights)\n- [Inference examples](#inference-examples)\n- [About this repository](#about-this-repository)\n- [Setup](#setup)\n- [Download the model](#download-the-model)\n- [Reference PyTorch implementation](#reference-pytorch-implementation)\n- [Reference Triton implementation (single GPU)](#reference-triton-implementation-single-gpu)\n- [Reference Metal implementation](#reference-metal-implementation)\n- [Harmony format & tools](#harmony-format--tools)\n- [Clients](#clients)\n- [Tools](#tools)\n- [Other details](#other-details)\n- [Contributing](#contributing)\n\n### Highlights\n\n- 
**Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment.\n- **Configurable reasoning effort:** Easily adjust the reasoning effort (low, medium, high) based on your specific use case and latency needs.\n- **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users.\n- **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning.\n- **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs.\n- **MXFP4 quantization:** The models were post-trained with MXFP4 quantization of the MoE weights, making `gpt-oss-120b` run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and the `gpt-oss-20b` model run within 16GB of memory. All evals were performed with the same MXFP4 quantization.\n\n### Inference examples\n\n#### Transformers\n\nYou can use `gpt-oss-120b` and `gpt-oss-20b` with the Transformers library. If you use Transformers' chat template, it will automatically apply the [harmony response format][harmony]. 
If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package.\n\n```python\nfrom transformers import pipeline\n\nmodel_id = \"openai/gpt-oss-120b\"\n\npipe = pipeline(\n    \"text-generation\",\n    model=model_id,\n    torch_dtype=\"auto\",\n    device_map=\"auto\",\n)\n\nmessages = [\n    {\"role\": \"user\", \"content\": \"Explain quantum mechanics clearly and concisely.\"},\n]\n\noutputs = pipe(\n    messages,\n    max_new_tokens=256,\n)\nprint(outputs[0][\"generated_text\"][-1])\n```\n\n[Learn more about how to use gpt-oss with Transformers.](https://cookbook.openai.com/articles/gpt-oss/run-transformers)\n\n#### vLLM\n\nvLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible web server. The following command will automatically download the model and start the server.\n\n```bash\nuv pip install --pre vllm==0.10.1+gptoss \\\n    --extra-index-url https://wheels.vllm.ai/gpt-oss/ \\\n    --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \\\n    --index-strategy unsafe-best-match\n\nvllm serve openai/gpt-oss-20b\n```\n\n[Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)\n\nOffline inference example: after installing vLLM as described above, also install the Harmony library with `uv pip install openai-harmony`, then run the following code.\n\n```python\nimport json\nimport os\n\n# Set before importing vLLM so the setting is in place when vLLM initializes\nos.environ[\"VLLM_USE_FLASHINFER_SAMPLER\"] = \"0\"\n\nfrom openai_harmony import (\n    HarmonyEncodingName,\n    load_harmony_encoding,\n    Conversation,\n    Message,\n    Role,\n    SystemContent,\n    DeveloperContent,\n)\nfrom vllm import LLM, SamplingParams\n\n# --- 1) Render the prefill with Harmony ---\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\nconvo = 
Conversation.from_messages(\n    [\n        Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),\n        Message.from_role_and_content(\n            Role.DEVELOPER,\n            DeveloperContent.new().with_instructions(\"Always respond in riddles\"),\n        ),\n        Message.from_role_and_content(Role.USER, \"What is the weather like in SF?\"),\n    ]\n)\n\nprefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)\n\n# Harmony stop tokens (pass to sampler so they won't be included in output)\nstop_token_ids = encoding.stop_tokens_for_assistant_actions()\n\n# --- 2) Run vLLM with prefill ---\nllm = LLM(\n    model=\"openai/gpt-oss-20b\",\n    trust_remote_code=True,\n    gpu_memory_utilization=0.95,\n    max_num_batched_tokens=4096,\n    max_model_len=5000,\n    tensor_parallel_size=1,\n)\n\nsampling = SamplingParams(\n    max_tokens=128,\n    temperature=1,\n    stop_token_ids=stop_token_ids,\n)\n\noutputs = llm.generate(\n    prompt_token_ids=[prefill_ids],   # batch of size 1\n    sampling_params=sampling,\n)\n\n# vLLM gives you both text and token IDs\ngen = outputs[0].outputs[0]\ntext = gen.text\noutput_tokens = gen.token_ids  # <-- these are the completion token IDs (no prefill)\n\n# --- 3) Parse the completion token IDs back into structured Harmony messages ---\nentries = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)\n\n# 'entries' is a sequence of structured conversation entries (assistant messages, tool calls, etc.).\nfor message in entries:\n    print(json.dumps(message.to_dict()))\n```\n\n#### PyTorch / Triton / Metal\n\nThese implementations are largely reference implementations for educational purposes and are not expected to be run in production.\n\n[Learn more below.](#reference-pytorch-implementation)\n\n#### Ollama\n\nIf you are trying to run `gpt-oss` on consumer hardware, you can use Ollama by running the following commands after [installing 
Ollama](https://ollama.com/download).\n\n```bash\n# gpt-oss-20b\nollama pull gpt-oss:20b\nollama run gpt-oss:20b\n\n# gpt-oss-120b\nollama pull gpt-oss:120b\nollama run gpt-oss:120b\n```\n\n[Learn more about how to use gpt-oss with Ollama.](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama)\n\n#### LM Studio\n\nIf you are using [LM Studio](https://lmstudio.ai/), you can use the following commands to download the models.\n\n```bash\n# gpt-oss-20b\nlms get openai/gpt-oss-20b\n# gpt-oss-120b\nlms get openai/gpt-oss-120b\n```\n\nCheck out our [awesome list](./awesome-gpt-oss.md) for a broader collection of gpt-oss resources and inference partners.\n\n## About this repository\n\nThis repository provides a collection of reference implementations:\n\n- **Inference:**\n  - [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4× H100 GPUs due to lack of optimization.\n  - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [PyTorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. 
using CUDA graphs and basic caching\n  - [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware\n- **Tools:**\n  - [`browser`](#browser) — a reference implementation of the browser tool the models were trained on\n  - [`python`](#python) — a stateless reference implementation of the python tool the models were trained on\n- **Client examples:**\n  - [`chat`](#terminal-chat) — a basic terminal chat application that uses the PyTorch, Triton, or vLLM implementations for inference along with the python and browser tools\n  - [`responses_api`](#responses-api) — an example Responses API-compatible server that implements the browser tool along with other Responses-compatible functionality\n\n## Setup\n\n### Requirements\n\n- Python 3.12\n- On macOS: Install the Xcode CLI tools with `xcode-select --install`\n- On Linux: These reference implementations require CUDA\n- On Windows: These reference implementations have not been tested on Windows. 
Try using solutions like Ollama if you are trying to run the model locally.\n\n### Installation\n\nIf you want to try any of the code, you can install it directly from [PyPI](https://pypi.org/project/gpt-oss/):\n\n```shell\n# if you just need the tools\npip install gpt-oss\n# if you want to try the torch implementation\npip install \"gpt-oss[torch]\"\n# if you want to try the triton implementation\npip install \"gpt-oss[triton]\"\n```\n\nIf you want to modify the code or try the metal implementation, set the project up locally:\n\n```shell\ngit clone https://github.com/openai/gpt-oss.git\ncd gpt-oss\nGPTOSS_BUILD_METAL=1 pip install -e \".[metal]\"\n```\n\n## Download the model\n\nYou can download the model weights from the [Hugging Face Hub](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) using the Hugging Face CLI:\n\n```shell\n# gpt-oss-120b\nhf download openai/gpt-oss-120b --include \"original/*\" --local-dir gpt-oss-120b/\n\n# gpt-oss-20b\nhf download openai/gpt-oss-20b --include \"original/*\" --local-dir gpt-oss-20b/\n```\n\n## Reference PyTorch implementation\n\nWe include an inefficient reference PyTorch implementation in [gpt_oss/torch/model.py](gpt_oss/torch/model.py). This code uses basic PyTorch operators to show the exact model architecture, with the small addition of tensor parallelism in the MoE layers so that the larger model can run with this code (e.g., on 4xH100 or 2xH200). In this implementation, we upcast all weights to BF16 and run the model in BF16.\n\nTo run the reference implementation, install the dependencies:\n\n```shell\npip install -e \".[torch]\"\n```\n\nAnd then run:\n\n```shell\n# On 4xH100:\ntorchrun --nproc-per-node=4 -m gpt_oss.generate gpt-oss-120b/original/\n```\n\n## Reference Triton implementation (single GPU)\n\nWe also include an optimized reference implementation that uses [an optimized triton MoE kernel](https://github.com/triton-lang/triton/tree/main/python/triton_kernels/triton_kernels) that supports MXFP4. 
It also includes some optimizations in the attention code to reduce memory usage. To run this implementation, nightly versions of Triton and PyTorch will be installed. This version can run `gpt-oss-120b` on a single 80GB GPU.\n\nTo install the reference Triton implementation, run:\n\n```shell\n# You need to install triton from source to use the triton implementation\ngit clone https://github.com/triton-lang/triton\ncd triton/\npip install -r python/requirements.txt\npip install -e . --verbose --no-build-isolation\npip install -e python/triton_kernels\n\n# From the root of your gpt-oss checkout, install the gpt-oss triton implementation\npip install -e \".[triton]\"\n```\n\nAnd then run:\n\n```shell\n# On 1xH100\nexport PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True\npython -m gpt_oss.generate --backend triton gpt-oss-120b/original/\n```\n\nIf you encounter `torch.OutOfMemoryError`, make sure the expandable allocator is turned on (as in the `PYTORCH_CUDA_ALLOC_CONF` export above) to avoid crashes when loading weights from the checkpoint.\n\n## Reference Metal implementation\n\nAdditionally, we provide a reference implementation for Metal that runs on Apple Silicon. 
This implementation is not production-ready but is accurate to the PyTorch implementation.\n\nThe implementation is compiled automatically when you run the `.[metal]` installation on an Apple Silicon device:\n\n```shell\nGPTOSS_BUILD_METAL=1 pip install -e \".[metal]\"\n```\n\nTo perform inference, you first need to convert the SafeTensors weights from Hugging Face into the format used by the Metal implementation:\n\n```shell\npython gpt_oss/metal/scripts/create-local-model.py -s <model_dir> -d <output_file>\n```\n\nOr download the pre-converted weights:\n\n```shell\nhf download openai/gpt-oss-120b --include \"metal/*\" --local-dir gpt-oss-120b/metal/\nhf download openai/gpt-oss-20b --include \"metal/*\" --local-dir gpt-oss-20b/metal/\n```\n\nTo test it, run:\n\n```shell\npython gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/model.bin -p \"why did the chicken cross the road?\"\n```\n\n## Harmony format & tools\n\nAlong with the model, we are also releasing a new chat format library, `harmony`, for interacting with the model. Check [this guide](https://cookbook.openai.com/articles/openai-harmony) for more information about harmony.\n\nWe also include two system tools for the model: browsing and a python container. Check [gpt_oss/tools](gpt_oss/tools) for the tool implementations.\n\n## Clients\n\n### Terminal Chat\n\nThe terminal chat application is a basic example of how to use the harmony format together with the PyTorch, Triton, and vLLM implementations. 
It also exposes the python and browser tools as optional tools.\n\n```bash\nusage: python -m gpt_oss.chat [-h] [-r REASONING_EFFORT] [-a] [-b] [--show-browser-results] [-p] [--developer-message DEVELOPER_MESSAGE] [-c CONTEXT] [--raw] [--backend {triton,torch,vllm}] FILE\n\nChat example\n\npositional arguments:\n  FILE                  Path to the SafeTensors checkpoint\n\noptions:\n  -h, --help            show this help message and exit\n  -r REASONING_EFFORT, --reasoning-effort REASONING_EFFORT\n                        Reasoning effort (default: low)\n  -a, --apply-patch     Make apply_patch tool available to the model (default: False)\n  -b, --browser         Use browser tool (default: False)\n  --show-browser-results\n                        Show browser results (default: False)\n  -p, --python          Use python tool (default: False)\n  --developer-message DEVELOPER_MESSAGE\n                        Developer message (default: )\n  -c CONTEXT, --context CONTEXT\n                        Max context length (default: 8192)\n  --raw                 Raw mode (does not render Harmony encoding) (default: False)\n  --backend {triton,torch,vllm}\n                        Inference backend (default: triton)\n```\n\n> [!NOTE]\n> The torch and triton implementations require the original checkpoint, located under `gpt-oss-120b/original/` or `gpt-oss-20b/original/`. vLLM instead uses the Hugging Face converted checkpoint in the `gpt-oss-120b/` or `gpt-oss-20b/` root directory.\n\n### Responses API\n\nWe also include an example Responses API server. This server does not implement every feature and event of the Responses API but should be compatible with most basic use cases and serve as inspiration for anyone building their own server. 
Some of our inference partners also offer their own Responses API.\n\nYou can start this server with the following inference backends:\n\n- `triton` — uses the triton implementation\n- `metal` — uses the metal implementation on Apple Silicon only\n- `ollama` — uses the Ollama `/api/generate` API as an inference solution\n- `vllm` — uses your installed vllm version to perform inference\n- `transformers` — uses your installed transformers version to perform local inference\n\n```bash\nusage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]\n\nResponses API server\n\noptions:\n  -h, --help                    show this help message and exit\n  --checkpoint FILE             Path to the SafeTensors checkpoint\n  --port PORT                   Port to run the server on\n  --inference-backend BACKEND   Inference backend to use\n```\n\n### Codex\n\nWe support [codex](https://github.com/openai/codex) as a client for gpt-oss. To run the 20b version, add the following to `~/.codex/config.toml`:\n\n```toml\ndisable_response_storage = true\nshow_reasoning_content = true\n\n[model_providers.local]\nname = \"local\"\nbase_url = \"http://localhost:11434/v1\"\n\n[profiles.oss]\nmodel = \"gpt-oss:20b\"\nmodel_provider = \"local\"\n```\n\nThis works with any Chat Completions-compatible server listening on port 11434, such as Ollama. Start the server and point codex to the oss model:\n\n```bash\nollama run gpt-oss:20b\ncodex -p oss\n```\n\n## Tools\n\n### Browser\n\n> [!WARNING]\n> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`YouComBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment. Currently, `YouComBackend` and `ExaBackend` are available. 
\n\nBoth gpt-oss models were trained to browse using the `browser` tool, which exposes the following three methods:\n\n- `search` to search for key phrases\n- `open` to open a particular page\n- `find` to look for contents on a page\n\n#### Usage\n\nTo enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser_tool()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:\n\n```python\nimport datetime\nfrom gpt_oss.tools.simple_browser import SimpleBrowserTool\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend\nfrom openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName\n\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\n# Each browser backend reads its API key from the environment:\n# the You.com backend requires YDC_API_KEY, and the Exa backend may need EXA_API_KEY.\nbackend = YouComBackend(\n    source=\"web\",\n)\n# To use Exa instead, import ExaBackend from gpt_oss.tools.simple_browser.backend:\n# backend = ExaBackend(\n#     source=\"web\",\n# )\nbrowser_tool = SimpleBrowserTool(backend=backend)\n\n# create a basic system prompt\nsystem_message_content = SystemContent.new().with_conversation_start_date(\n    datetime.datetime.now().strftime(\"%Y-%m-%d\")\n)\n\n# set to False to prompt the model without the browser tool\nuse_browser_tool = True\nif use_browser_tool:\n    # enables the tool using this implementation's own definition\n    system_message_content = system_message_content.with_tools(browser_tool.tool_config)\n    # alternatively, if your tool is not stateless, use the built-in definition instead:\n    # system_message_content = system_message_content.with_browser_tool()\n\n# construct the system message\nsystem_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)\n\n# create the overall prompt\nmessages = [system_message, 
Message.from_role_and_content(Role.USER, \"What's the weather in SF?\")]\nconversation = Conversation.from_messages(messages)\n\n# convert to tokens\ntoken_ids = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)\n\n# perform inference with your backend of choice, producing `output_tokens`\n# ...\n\n# parse the output and append it to the conversation\nnew_messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)\nmessages.extend(new_messages)\nlast_message = new_messages[-1]\nif last_message.recipient is not None and last_message.recipient.startswith(\"browser\"):\n    # perform the browser call; `process` yields response messages asynchronously,\n    # so this must run inside an async function\n    response_messages = [m async for m in browser_tool.process(last_message)]\n\n    # extend the current messages and run inference again\n    messages.extend(response_messages)\n```\n\n#### Details\n\nTo control the context window size, this tool uses a scrollable window of text that the model can interact with. For example, it might fetch the first 50 lines of a page and then scroll to the next 20 lines. The model has also been trained to cite this tool's results in its answers.\n\nTo improve performance, the tool caches requests so that the model can revisit a different part of a page without reloading it. For that reason, you should create a new browser instance for every request.\n\n### Python\n\nThe model was trained to use a python tool to perform calculations and other actions as part of its chain-of-thought. During training, the model used a stateful tool, which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result, `PythonTool` defines its own tool description to override the definition in [`openai-harmony`][harmony].\n\n> [!WARNING]\n> This implementation runs in a permissive Docker container, which could be problematic in cases such as prompt injection. It is meant as an example; you should implement your own container restrictions in production.\n\n#### Usage\n\nTo enable the python tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. 
You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:\n\n```python\nimport datetime\nfrom gpt_oss.tools.python_docker.docker_tool import PythonTool\nfrom openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName\n\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\npython_tool = PythonTool()\n\n# create a basic system prompt\nsystem_message_content = SystemContent.new().with_conversation_start_date(\n    datetime.datetime.now().strftime(\"%Y-%m-%d\")\n)\n\n# set to False to prompt the model without the python tool\nuse_python_tool = True\nif use_python_tool:\n    # enables the tool, making sure that the prompt uses the stateless tool description\n    system_message_content = system_message_content.with_tools(python_tool.tool_config)\n    # alternatively, if your tool is not stateless, use the built-in definition instead:\n    # system_message_content = system_message_content.with_python()\n\n# construct the system message\nsystem_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)\n\n# create the overall prompt\nmessages = [system_message, Message.from_role_and_content(Role.USER, \"What's the square root of 9001?\")]\nconversation = Conversation.from_messages(messages)\n\n# convert to tokens\ntoken_ids = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)\n\n# perform inference with your backend of choice, producing `output_tokens`\n# ...\n\n# parse the output and append it to the conversation\nnew_messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)\nmessages.extend(new_messages)\nlast_message = new_messages[-1]\nif last_message.recipient == \"python\":\n    # perform the python call; `process` yields response messages asynchronously,\n    # so this must run inside an async function\n    response_messages = [m async for m in python_tool.process(last_message)]\n\n    # extend the current messages and run inference again\n    messages.extend(response_messages)\n```\n\n### Apply Patch\n\n`apply_patch` can be used to create, update, or delete files locally.\n\n## Other details\n\n### Precision format\n\nWe released the 
models with native quantization support. Specifically, we use [MXFP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) for the linear projection weights in the MoE layer. We store each MoE tensor in two parts:\n\n- `tensor.blocks` stores the actual fp4 values, packed two per `uint8` value.\n- `tensor.scales` stores the block scales. For all MXFP4 tensors, block scaling is applied along the last dimension.\n\nAll other tensors are stored in BF16. We also recommend using BF16 as the activation precision for the model.\n\n### Recommended Sampling Parameters\n\nWe recommend sampling with `temperature=1.0` and `top_p=1.0`.\n\n## Contributing\n\nThe reference implementations in this repository are meant as a starting point and inspiration. Outside of bug fixes, we do not intend to accept new feature contributions. If you build implementations based on this code, such as new tool implementations, you are welcome to contribute them to the [`awesome-gpt-oss.md`](./awesome-gpt-oss.md) file.\n\n[harmony]: https://github.com/openai/harmony\n\n## Citation\n\n```bibtex\n@misc{openai2025gptoss120bgptoss20bmodel,\n      title={gpt-oss-120b & gpt-oss-20b Model Card}, \n      author={OpenAI},\n      year={2025},\n      eprint={2508.10925},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL},\n      url={https://arxiv.org/abs/2508.10925}, \n}\n```\n"
  },
  {
    "path": "USAGE_POLICY",
    "content": "We aim for our tools to be used safely, responsibly, and democratically, while maximizing your control over how you use them. By using OpenAI gpt-oss-120b and gpt-oss-20b, you agree to comply with all applicable law."
  },
  {
    "path": "_build/gpt_oss_build_backend/__init__.py",
    "content": "\"\"\"In-tree PEP 517 backend package for gpt-oss.\"\"\" "
  },
  {
    "path": "_build/gpt_oss_build_backend/backend.py",
    "content": "\"\"\"\nBuild backend for gpt-oss that supports two modes:\n\n1) Default (pure wheel for PyPI)\n   - Delegates to setuptools.build_meta.\n   - Produces a py3-none-any wheel so PyPI accepts it (no linux_x86_64 tag).\n\n2) Optional Metal/C extension build (local only)\n   - If the environment variable GPTOSS_BUILD_METAL is set to a truthy value\n     (1/true/on/yes), delegates to scikit_build_core.build.\n   - Dynamically injects build requirements (scikit-build-core, cmake, ninja,\n     pybind11) only for this mode.\n\nWhy this is needed\n- PyPI rejects Linux wheels tagged linux_x86_64; manylinux/musllinux is required\n  for binary wheels. We ship a pure wheel by default, but still allow developers\n  to build/install the native Metal backend locally when needed.\n\nTypical usage\n- Publish pure wheel: `python -m build` (do not set GPTOSS_BUILD_METAL).\n- Local Metal dev: `GPTOSS_BUILD_METAL=1 pip install -e \".[metal]\"`.\n- CI: keep GPTOSS_BUILD_METAL unset for releases; set it in internal jobs that\n  exercise the extension.\n\nNotes\n- The base package remains importable without the extension. 
The Metal backend\n  is only used when `gpt_oss.metal` is explicitly imported.\n- This file is discovered via `backend-path = [\"_build\"]` and\n  `build-backend = \"gpt_oss_build_backend.backend\"` in pyproject.toml.\n\"\"\"\nimport os\nfrom importlib import import_module\nfrom typing import Any, Mapping, Sequence\n\n\nTRUE_VALUES = {\"1\", \"true\", \"TRUE\", \"on\", \"ON\", \"yes\", \"YES\"}\n\n\ndef _use_metal_backend() -> bool:\n    return str(os.environ.get(\"GPTOSS_BUILD_METAL\", \"\")).strip() in TRUE_VALUES\n\n\ndef _setuptools_backend():\n    from setuptools import build_meta as _bm  # type: ignore\n\n    return _bm\n\n\ndef _scikit_build_backend():\n    return import_module(\"scikit_build_core.build\")\n\n\ndef _backend():\n    return _scikit_build_backend() if _use_metal_backend() else _setuptools_backend()\n\n\n# Required PEP 517 hooks\n\ndef build_wheel(\n    wheel_directory: str,\n    config_settings: Mapping[str, Any] | None = None,\n    metadata_directory: str | None = None,\n) -> str:\n    return _backend().build_wheel(wheel_directory, config_settings, metadata_directory)\n\n\ndef build_sdist(\n    sdist_directory: str, config_settings: Mapping[str, Any] | None = None\n) -> str:\n    return _backend().build_sdist(sdist_directory, config_settings)\n\n\ndef prepare_metadata_for_build_wheel(\n    metadata_directory: str, config_settings: Mapping[str, Any] | None = None\n) -> str:\n    # Fallback if backend doesn't implement it\n    be = _backend()\n    fn = getattr(be, \"prepare_metadata_for_build_wheel\", None)\n    if fn is None:\n        # setuptools exposes it; scikit-build-core may not. 
Fall back to setuptools' implementation.\n        return _setuptools_backend().prepare_metadata_for_build_wheel(\n            metadata_directory, config_settings\n        )\n    return fn(metadata_directory, config_settings)\n\n\n# Optional hooks\n\ndef build_editable(\n    editable_directory: str, config_settings: Mapping[str, Any] | None = None, metadata_directory: str | None = None\n) -> str:\n    be = _backend()\n    fn = getattr(be, \"build_editable\", None)\n    if fn is None:\n        # setuptools implements build_editable; other backends might not\n        raise RuntimeError(\"Editable installs not supported by the selected backend\")\n    # PEP 660: forward metadata_directory so prepared metadata is reused\n    return fn(editable_directory, config_settings, metadata_directory)\n\n\ndef get_requires_for_build_wheel(\n    config_settings: Mapping[str, Any] | None = None,\n) -> Sequence[str]:\n    if _use_metal_backend():\n        # Add dynamic build requirements only when building the Metal backend\n        return [\n            \"scikit-build-core>=0.10\",\n            \"pybind11>=2.12\",\n            \"cmake>=3.26\",\n            \"ninja\",\n        ]\n    # setuptools usually returns []\n    return list(_setuptools_backend().get_requires_for_build_wheel(config_settings))\n\n\ndef get_requires_for_build_sdist(\n    config_settings: Mapping[str, Any] | None = None,\n) -> Sequence[str]:\n    # No special requirements for SDist\n    be = _backend()\n    fn = getattr(be, \"get_requires_for_build_sdist\", None)\n    if fn is None:\n        return []\n    return list(fn(config_settings))\n\n\ndef get_requires_for_build_editable(\n    config_settings: Mapping[str, Any] | None = None,\n) -> Sequence[str]:\n    if _use_metal_backend():\n        return [\n            \"scikit-build-core>=0.10\",\n            \"pybind11>=2.12\",\n            \"cmake>=3.26\",\n            \"ninja\",\n        ]\n    be = _setuptools_backend()\n    fn = getattr(be, \"get_requires_for_build_editable\", None)\n    if fn is None:\n        return []\n    return 
list(fn(config_settings)) "
  },
  {
    "path": "awesome-gpt-oss.md",
    "content": "![gpt-oss](./docs/gpt-oss.svg)\n\n# Awesome gpt-oss\n\nThis is a list of guides and resources to help you get started with the gpt-oss models.\n\n- [Inference](#inference)\n  - [Local](#local)\n  - [Server](#server)\n  - [Cloud](#cloud)\n- [Examples & Tutorials](#examples--tutorials)\n- [Tools](#tools)\n- [Training](#training)\n\n## Inference\n\n### Local\n\n- Ollama\n  - [How to run gpt-oss locally with Ollama](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama)\n  - [Ollama & gpt-oss launch blog](https://ollama.com/blog/gpt-oss)\n  - [Check out the models on Ollama](https://ollama.com/library/gpt-oss)\n- LM Studio\n  - [LM Studio & gpt-oss launch blog](https://lmstudio.ai/blog/gpt-oss)\n  - [Use gpt-oss-20b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-20b)\n  - [Use gpt-oss-120b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-120b)\n- Hugging Face & Transformers\n  - [How to run gpt-oss with Transformers](https://cookbook.openai.com/articles/gpt-oss/run-transformers)\n  - [Hugging Face & gpt-oss launch blog](https://huggingface.co/blog/welcome-openai-gpt-oss)\n  - [Collection of Hugging Face examples](https://github.com/huggingface/gpt-oss-recipes)\n- NVIDIA\n  - [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss)\n- AMD\n  - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html)\n  - [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html)\n- llama.cpp\n  - [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)\n  - [Running gpt-oss with Unsloth GGUFs](https://docs.unsloth.ai/new/gpt-oss-how-to-run-and-fine-tune#run-gpt-oss-20b)\n\n### Server\n\n- vLLM\n  - [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm)\n  - [vLLM & gpt-oss 
recipes](https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html)\n- NVIDIA\n  - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/run-nvidia)\n  - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)\n- AMD\n  - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html)\n\n### Cloud\n\n- Groq\n  - [Groq & gpt-oss launch blog](https://groq.com/blog/day-zero-support-for-openai-open-models)\n  - [gpt-oss-120b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-120b)\n  - [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b)\n  - [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search)\n  - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution)\n  - [Responses API on Groq](https://console.groq.com/docs/responses-api)\n- NVIDIA\n  - [NVIDIA launch blog post](https://blogs.nvidia.com/blog/openai-gpt-oss/)\n  - [NVIDIA & gpt-oss developer launch blog post](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/)\n  - Use [gpt-oss-120b](https://build.nvidia.com/openai/gpt-oss-120b) and [gpt-oss-20b](https://build.nvidia.com/openai/gpt-oss-20b) on NVIDIA's Cloud\n- Cloudflare\n  - [Cloudflare & gpt-oss launch blog post](https://blog.cloudflare.com/openai-gpt-oss-on-workers-ai)\n  - [gpt-oss-120b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-120b)\n  - [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b)\n- AMD\n  - [gpt-oss-120B on AMD 
MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot)\n- AWS\n  - Deploy via Tensorfuse: [Deploy gpt-oss for both 20b and 120b models on AWS EKS](https://tensorfuse.io/docs/guides/modality/text/openai_oss)\n  - [AWS launch blog post](https://aws.amazon.com/blogs/aws/openai-open-weight-models-now-available-on-aws/)\n- Google Colab\n  - [gpt-oss-20b inference notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/GPT_OSS_MXFP4_(20B)-Inference.ipynb)\n\n## Examples & Tutorials\n\n- [OpenAI harmony response format](https://cookbook.openai.com/articles/openai-harmony)\n\n## Tools\n\n- [Example `python` tool for gpt-oss](./gpt_oss/tools/python_docker/)\n- [Example `browser` tool for gpt-oss](./gpt_oss/tools/simple_browser/)\n\n## Training\n\n- [Hugging Face TRL examples](https://github.com/huggingface/gpt-oss-recipes)\n- [LlamaFactory examples](https://llamafactory.readthedocs.io/en/latest/advanced/best_practice/gpt-oss.html)\n- [Unsloth examples](https://docs.unsloth.ai/basics/gpt-oss-how-to-run-and-fine-tune)\n\n### Reinforcement Learning\n\n- [Auto-solving the 2048 game](https://github.com/openai/gpt-oss/blob/main/examples/reinforcement-fine-tuning.ipynb)\n\n## Contributing\n\nFeel free to open a PR to add your own guides and resources on how to run gpt-oss. We will try to review it and add it here.\n"
  },
  {
    "path": "compatibility-test/.gitignore",
    "content": "# Logs\nlogs\n*.log\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\nlerna-debug.log*\n\n# Diagnostic reports (https://nodejs.org/api/report.html)\nreport.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json\n\n# Runtime data\npids\n*.pid\n*.seed\n*.pid.lock\n\n# Directory for instrumented libs generated by jscoverage/JSCover\nlib-cov\n\n# Coverage directory used by tools like istanbul\ncoverage\n*.lcov\n\n# nyc test coverage\n.nyc_output\n\n# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)\n.grunt\n\n# Bower dependency directory (https://bower.io/)\nbower_components\n\n# node-waf configuration\n.lock-wscript\n\n# Compiled binary addons (https://nodejs.org/api/addons.html)\nbuild/Release\n\n# Dependency directories\nnode_modules/\njspm_packages/\n\n# Snowpack dependency directory (https://snowpack.dev/)\nweb_modules/\n\n# TypeScript cache\n*.tsbuildinfo\n\n# Optional npm cache directory\n.npm\n\n# Optional eslint cache\n.eslintcache\n\n# Optional stylelint cache\n.stylelintcache\n\n# Optional REPL history\n.node_repl_history\n\n# Output of 'npm pack'\n*.tgz\n\n# Yarn Integrity file\n.yarn-integrity\n\n# dotenv environment variable files\n.env\n.env.*\n!.env.example\n\n# parcel-bundler cache (https://parceljs.org/)\n.cache\n.parcel-cache\n\n# Next.js build output\n.next\nout\n\n# Nuxt.js build / generate output\n.nuxt\ndist\n\n# Gatsby files\n.cache/\n# Comment in the public line in if your project uses Gatsby and not Next.js\n# https://nextjs.org/blog/next-9-1#public-directory-support\n# public\n\n# vuepress build output\n.vuepress/dist\n\n# vuepress v2.x temp and cache directory\n.temp\n.cache\n\n# Sveltekit cache directory\n.svelte-kit/\n\n# vitepress build output\n**/.vitepress/dist\n\n# vitepress cache directory\n**/.vitepress/cache\n\n# Docusaurus cache and generated files\n.docusaurus\n\n# Serverless directories\n.serverless/\n\n# FuseBox cache\n.fusebox/\n\n# DynamoDB Local files\n.dynamodb/\n\n# Firebase cache 
directory\n.firebase/\n\n# TernJS port file\n.tern-port\n\n# Stores VSCode versions used for testing VSCode extensions\n.vscode-test\n\n# yarn v3\n.pnp.*\n.yarn/*\n!.yarn/patches\n!.yarn/plugins\n!.yarn/releases\n!.yarn/sdks\n!.yarn/versions\n\n# Vite logs files\nvite.config.js.timestamp-*\nvite.config.ts.timestamp-*\n\nrollout_*.jsonl\nanalysis_*.json"
  },
  {
    "path": "compatibility-test/README.md",
    "content": "# API Compatibility Test\n\nThis script uses the Agents SDK for TypeScript and the underlying OpenAI client to verify both the shape of the API calls and whether the API performs tool calling.\n\n## What it tests\n\n1. Whether the API responses have the expected shape (Chat Completions or Responses API)\n2. Whether the model calls the expected tool with arguments that match the tool's schema\n\n## How to run\n\n0. Run `npm install` in this directory.\n1. Update `providers.ts` to create an entry for the API to test. Change `vllm` to the provider name of your choice. Use `chat` for Chat Completions tests and `responses` for Responses API tests.\n2. Run an initial quick test to make sure things work. This runs only one test:\n\n```bash\nnpm start -- --provider <name> -n 1 -k 1\n```\n\n3. Run the full test suite (runs each test 5 times to check consistency):\n\n```bash\nnpm start -- --provider <name> -k 5\n```\n\n## Considerations\n\n1. The tests fail if the API response shape does not match the expected behavior.\n2. Events in the Chat Completions API are currently not tested.\n3. If schema validation succeeds but the input is wrong, the test still passes. A wrong input is more likely a prompt engineering or validator issue than an API issue, since the model still produced input that matches the schema.\n"
  },
  {
    "path": "compatibility-test/analysis.ts",
    "content": "export function analyze(caseResults: any[], tries: number) {\n  // Group results by unique task: test_case + apiType\n  type TaskKey = string;\n  const taskKeyFor = (r: any): TaskKey =>\n    `${r.test_case}::${r.result?.apiType}`;\n\n  const successesByTask: Map<TaskKey, Map<number, boolean>> = new Map();\n\n  // Count wrong-input tool calls (schema correct but incorrect arguments)\n  let wrongInputToolCalls = 0;\n\n  // Count invalid response shapes per API type\n  const totalByApiType: Record<string, number> = {};\n  const invalidByApiType: Record<string, number> = {};\n\n  for (const r of caseResults) {\n    if (!r?.result || typeof r.result.apiType !== \"string\") continue;\n\n    // Parse attempt index from run_id `${i}_${k}` safely\n    let attemptIndex: number | undefined;\n    if (typeof r.run_id === \"string\") {\n      const parts = r.run_id.split(\"_\");\n      const k = Number(parts[1]);\n      if (Number.isFinite(k)) attemptIndex = k;\n    }\n\n    const key = taskKeyFor(r);\n    if (!successesByTask.has(key)) successesByTask.set(key, new Map());\n    if (attemptIndex != null) {\n      successesByTask.get(key)!.set(attemptIndex, Boolean(r.success));\n    }\n\n    const d = r.result.toolCallingDetails ?? {};\n    const calledToolAtLeastOnce = Boolean(d.calledToolAtLeastOnce);\n    const calledToolWithRightSchema = Boolean(d.calledToolWithRightSchema);\n    const calledToolWithRightArguments = Boolean(\n      d.calledToolWithRightArguments\n    );\n    if (\n      calledToolAtLeastOnce &&\n      calledToolWithRightSchema &&\n      !calledToolWithRightArguments\n    ) {\n      wrongInputToolCalls++;\n    }\n\n    // Track invalid/total per apiType for response shape\n    const apiType = r.result.apiType as string;\n    totalByApiType[apiType] = (totalByApiType[apiType] ?? 0) + 1;\n    const isValidResponse = r.result.validResponse === true;\n    if (!isValidResponse) {\n      invalidByApiType[apiType] = (invalidByApiType[apiType] ?? 
0) + 1;\n    }\n  }\n\n  const totalTasks = successesByTask.size;\n\n  // Compute pass@k and pass^k for k = 1..tries\n  const passAtKByK: number[] = [];\n  const passHatKByK: number[] = [];\n\n  for (let k = 1; k <= tries; k++) {\n    let tasksSuccessfulK = 0; // any success in first k attempts\n    let tasksAllSuccessfulK = 0; // all success in first k attempts\n\n    for (const [, attemptsMap] of successesByTask) {\n      let anySuccess = false;\n      let allSuccess = true;\n      for (let i = 0; i < k; i++) {\n        const v = attemptsMap.get(i) === true;\n        anySuccess = anySuccess || v;\n        if (!v) allSuccess = false;\n      }\n      if (anySuccess) tasksSuccessfulK++;\n      if (allSuccess) tasksAllSuccessfulK++;\n    }\n\n    const passAtK = totalTasks > 0 ? tasksSuccessfulK / totalTasks : 0;\n    const passHatK = totalTasks > 0 ? tasksAllSuccessfulK / totalTasks : 0;\n    passAtKByK.push(passAtK);\n    passHatKByK.push(passHatK);\n  }\n\n  // Convenience: final k=tries values\n  const passAtK = passAtKByK[tries - 1] ?? 0;\n  const passHatK = passHatKByK[tries - 1] ?? 0;\n\n  return {\n    totalTasks,\n    passAtKByK,\n    passHatKByK,\n    passAtK,\n    passHatK,\n    wrongInputToolCalls,\n    // New stats for invalid response shapes per API\n    invalidByApiType,\n    totalByApiType,\n  };\n}\n\nexport function printAnalysis(\n  stats: ReturnType<typeof analyze>,\n  caseResults: any[],\n  provider: string,\n  selectedLines: string[],\n  tries: number,\n  skipped: number,\n  analysisFile: string\n) {\n  const formatPerK = (arr: number[]) =>\n    Array.from({ length: tries }, (_, i) => {\n      const v = arr[i] ?? 
0;\n      return `${i + 1}=${v.toFixed(3)}`;\n    }).join(\", \");\n\n  console.log(\"Summary:\");\n  console.log(`  Provider: ${provider}`);\n  console.log(`  Total input cases: ${selectedLines.length}`);\n  console.log(`  Tries: ${tries}`);\n  console.log(`  Total tasks: ${stats.totalTasks}`);\n  console.log(`  Total runs: ${caseResults.length}`);\n  // Conditionally print invalid response shape stats per API type\n  if ((stats.totalByApiType[\"responses\"] ?? 0) > 0) {\n    const bad = stats.invalidByApiType[\"responses\"] ?? 0;\n    const tot = stats.totalByApiType[\"responses\"] ?? 0;\n    console.log(`  Invalid Responses API responses: ${bad} (out of ${tot})`);\n  }\n  if ((stats.totalByApiType[\"chat\"] ?? 0) > 0) {\n    const bad = stats.invalidByApiType[\"chat\"] ?? 0;\n    const tot = stats.totalByApiType[\"chat\"] ?? 0;\n    console.log(\n      `  Invalid Chat Completions API responses: ${bad} (out of ${tot})`\n    );\n  }\n  console.log(`  pass@k (k=1..${tries}): ${formatPerK(stats.passAtKByK)}`);\n  console.log(`  pass^k (k=1..${tries}): ${formatPerK(stats.passHatKByK)}`);\n  console.log(`  pass@k (k=${tries}): ${stats.passAtK.toFixed(3)}`);\n  console.log(`  pass^k (k=${tries}): ${stats.passHatK.toFixed(3)}`);\n  console.log(`  Wrong-input tool calls: ${stats.wrongInputToolCalls}`);\n  console.log(`  Invalid cases.jsonl lines: ${skipped}`);\n  console.log(`  Analysis written to ${analysisFile}`);\n}\n"
  },
  {
    "path": "compatibility-test/cases.jsonl",
    "content": "{\"tool_name\":\"get_system_health\",\"input\":\"Hey, quick check: is everything up and running?\",\"expected_arguments\":\"{}\"}\n{\"tool_name\":\"get_system_health\",\"input\":\"Status report please.\",\"expected_arguments\":\"{}\"}\n{\"tool_name\":\"get_system_health\",\"input\":\"Can you confirm the LLM health before we start?\",\"expected_arguments\":\"{}\"}\n{\"tool_name\":\"get_system_health\",\"input\":\"Need a health snapshot.\",\"expected_arguments\":\"{}\"}\n{\"tool_name\":\"get_system_health\",\"input\":\"Hi, what's the current system health?\",\"expected_arguments\":\"{}\"}\n{\"tool_name\":\"markdown_to_html\",\"input\":\"Convert this markdown to HTML:\\n\\n# Title\\n\\nSome *italic* text.\",\"expected_arguments\":\"{\\\"markdown\\\":\\\"# Title\\\\n\\\\nSome *italic* text.\\\"}\"}\n{\"tool_name\":\"markdown_to_html\",\"input\":\"Hey, could you turn `## Docs` into HTML?\",\"expected_arguments\":\"{\\\"markdown\\\":\\\"## Docs\\\"}\"}\n{\"tool_name\":\"markdown_to_html\",\"input\":\"Please render the following markdown:\\n\\n- item 1\\n- item 2\",\"expected_arguments\":\"{\\\"markdown\\\":\\\"- item 1\\\\n- item 2\\\"}\"}\n{\"tool_name\":\"markdown_to_html\",\"input\":\"I have `**bold**` markdown; give me HTML.\",\"expected_arguments\":\"{\\\"markdown\\\":\\\"**bold**\\\"}\"}\n{\"tool_name\":\"markdown_to_html\",\"input\":\"Markdown to HTML: > quote\",\"expected_arguments\":\"{\\\"markdown\\\":\\\"> quote\\\"}\"}\n{\"tool_name\":\"detect_language\",\"input\":\"Hey, what language is this: 'Buenos días, ¿cómo estás?'\",\"expected_arguments\":\"{\\\"text\\\":\\\"Buenos días, ¿cómo estás?\\\"}\"}\n{\"tool_name\":\"detect_language\",\"input\":\"Identify the language: \\\"Guten Morgen\\\"\",\"expected_arguments\":\"{\\\"text\\\":\\\"Guten Morgen\\\"}\"}\n{\"tool_name\":\"detect_language\",\"input\":\"Language detection needed: 
こんにちは、お元気ですか？\",\"expected_arguments\":\"{\\\"text\\\":\\\"こんにちは、お元気ですか？\\\"}\"}\n{\"tool_name\":\"detect_language\",\"input\":\"Detect language for: 'Привет, как дела?'\",\"expected_arguments\":\"{\\\"text\\\":\\\"Привет, как дела?\\\"}\"}\n{\"tool_name\":\"detect_language\",\"input\":\"What language is 'Bonjour tout le monde'?\",\"expected_arguments\":\"{\\\"text\\\":\\\"Bonjour tout le monde\\\"}\"}\n{\"tool_name\":\"generate_chart\",\"input\":\"Plot a simple line chart for these points: (1,2),(2,4),(3,9).\",\"expected_arguments\":\"{\\\"data\\\":[[1,2],[2,4],[3,9]],\\\"chart_type\\\":\\\"line\\\"}\"}\n{\"tool_name\":\"generate_chart\",\"input\":\"Hey, can I get a bar chart of my sales: 10, 20, 30 across Q1–Q3?\",\"expected_arguments\":\"{\\\"data\\\":[[1,10],[2,20],[3,30]],\\\"chart_type\\\":\\\"bar\\\",\\\"title\\\":\\\"Quarterly Sales\\\"}\"}\n{\"tool_name\":\"generate_chart\",\"input\":\"Make a scatter chart titled 'Experiment' with x label Time and y label Value for data [ [0,1], [1,1.5], [2,2.2] ].\",\"expected_arguments\":\"{\\\"data\\\":[[0,1],[1,1.5],[2,2.2]],\\\"chart_type\\\":\\\"scatter\\\",\\\"title\\\":\\\"Experiment\\\",\\\"x_label\\\":\\\"Time\\\",\\\"y_label\\\":\\\"Value\\\"}\"}\n{\"tool_name\":\"generate_chart\",\"input\":\"Create a line chart of temperatures 70,72,68,65 over 4 days, label x as 'Day'.\",\"expected_arguments\":\"{\\\"data\\\":[[1,70],[2,72],[3,68],[4,65]],\\\"chart_type\\\":\\\"line\\\",\\\"x_label\\\":\\\"Day\\\"}\"}\n{\"tool_name\":\"generate_chart\",\"input\":\"Visualize visits per day with a bar chart; numbers: 100,150,120.\",\"expected_arguments\":\"{\\\"data\\\":[[1,100],[2,150],[3,120]],\\\"chart_type\\\":\\\"bar\\\",\\\"title\\\":\\\"Daily Visits\\\",\\\"y_label\\\":\\\"Visitors\\\"}\"}\n{\"tool_name\":\"query_database\",\"input\":\"Give me the ids and emails from users table, limit 
5.\",\"expected_arguments\":\"{\\\"table\\\":\\\"users\\\",\\\"columns\\\":[\\\"id\\\",\\\"email\\\"],\\\"limit\\\":5}\"}\n{\"tool_name\":\"query_database\",\"input\":\"Hey, fetch order_id and amount from orders where status is 'shipped'.\",\"expected_arguments\":\"{\\\"table\\\":\\\"orders\\\",\\\"columns\\\":[\\\"order_id\\\",\\\"amount\\\"],\\\"filters\\\":\\\"status = 'shipped'\\\"}\"}\n{\"tool_name\":\"query_database\",\"input\":\"Retrieve name and price from products ordered by price descending, top 10 please.\",\"expected_arguments\":\"{\\\"table\\\":\\\"products\\\",\\\"columns\\\":[\\\"name\\\",\\\"price\\\"],\\\"limit\\\":10,\\\"order_by\\\":\\\"price DESC\\\"}\"}\n{\"tool_name\":\"query_database\",\"input\":\"I need the first 3 log entries from audit_log table.\",\"expected_arguments\":\"{\\\"table\\\":\\\"audit_log\\\",\\\"columns\\\":[\\\"id\\\",\\\"timestamp\\\",\\\"action\\\"],\\\"limit\\\":3}\"}\n{\"tool_name\":\"query_database\",\"input\":\"Query the customers table for name, city where city = 'Berlin'.\",\"expected_arguments\":\"{\\\"table\\\":\\\"customers\\\",\\\"columns\\\":[\\\"name\\\",\\\"city\\\"],\\\"filters\\\":\\\"city = 'Berlin'\\\"}\"}\n{\"tool_name\":\"get_weather\",\"input\":\"What's the weather in San Francisco right now?\",\"expected_arguments\":\"{\\\"location\\\":\\\"San Francisco\\\"}\"}\n{\"tool_name\":\"get_weather\",\"input\":\"Weather for Tokyo, please.\",\"expected_arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}\n{\"tool_name\":\"get_weather\",\"input\":\"Get me the current weather for 10001.\",\"expected_arguments\":\"{\\\"location\\\":\\\"10001\\\"}\"}\n{\"tool_name\":\"get_weather\",\"input\":\"How's the weather in Paris today?\",\"expected_arguments\":\"{\\\"location\\\":\\\"Paris\\\"}\"}\n{\"tool_name\":\"get_weather\",\"input\":\"Check the weather for Sydney.\",\"expected_arguments\":\"{\\\"location\\\":\\\"Sydney\\\"}\"}\n"
  },
  {
    "path": "compatibility-test/index.ts",
    "content": "import { parseArgs } from \"node:util\";\nimport { createWriteStream } from \"node:fs\";\nimport { readFile, writeFile } from \"node:fs/promises\";\nimport path from \"node:path\";\nimport process from \"node:process\";\nimport { runCase, RunCaseSummary } from \"./runCase\";\nimport { Listr, ListrTaskWrapper } from \"listr2\";\nimport { analyze, printAnalysis } from \"./analysis\";\n\nfunction formatTimestamp(d: Date): string {\n  const pad = (n: number) => String(n).padStart(2, \"0\");\n  const yyyy = d.getFullYear();\n  const mm = pad(d.getMonth() + 1);\n  const dd = pad(d.getDate());\n  const hh = pad(d.getHours());\n  const mi = pad(d.getMinutes());\n  const ss = pad(d.getSeconds());\n  return `${yyyy}${mm}${dd}_${hh}${mi}${ss}`;\n}\n\nasync function main() {\n  const args = parseArgs({\n    options: {\n      cases: { type: \"string\", short: \"c\", default: \"cases.jsonl\" },\n      provider: { type: \"string\", short: \"p\", default: \"openai\" },\n      streaming: { type: \"boolean\", short: \"s\", default: false },\n      maxTurns: { type: \"string\", short: \"t\", default: \"10\" },\n      n: { type: \"string\", short: \"n\" },\n      // no short alias: \"-s\" already belongs to --streaming\n      strict: { type: \"boolean\", default: false },\n      tries: { type: \"string\", short: \"k\", default: \"1\" },\n    },\n  });\n  const casesPathArg = args.values.cases;\n  const provider = args.values.provider as string;\n  const streaming = Boolean(args.values.streaming);\n  const maxTurns = Number(args.values.maxTurns ?? 10);\n  const nRaw = args.values.n as string | undefined;\n  const triesRaw = args.values.tries as string | undefined;\n  const tries = triesRaw != null ? Number(triesRaw) : 1;\n  const limit = nRaw != null ? 
Number(nRaw) : undefined;\n  if (limit != null && (!Number.isFinite(limit) || limit <= 0)) {\n    console.error(\"--n must be a positive integer\");\n    process.exitCode = 1;\n    return;\n  }\n\n  if (!casesPathArg) {\n    console.error(\"--cases is required (path to JSONL file)\");\n    process.exitCode = 1;\n    return;\n  }\n\n  const casesPath = path.isAbsolute(casesPathArg)\n    ? casesPathArg\n    : path.join(process.cwd(), casesPathArg);\n\n  const timestamp = formatTimestamp(new Date());\n  const defaultFilename = `rollout_${provider}_${timestamp}.jsonl`;\n  const outputFile = path.join(process.cwd(), defaultFilename);\n  const analysisFile = path.join(\n    process.cwd(),\n    `analysis_${provider}_${timestamp}.json`\n  );\n\n  let fileContent: string;\n  try {\n    fileContent = await readFile(casesPath, \"utf8\");\n  } catch (err: any) {\n    console.error(\n      `Failed to read cases file at ${casesPath}: ${err?.message ?? err}`\n    );\n    process.exitCode = 1;\n    return;\n  }\n\n  const lines = fileContent\n    .split(/\\r?\\n/)\n    .map((l) => l.trim())\n    .filter((l) => l.length > 0);\n\n  const selectedLines =\n    typeof limit === \"number\" ? lines.slice(0, limit) : lines;\n\n  const out = createWriteStream(outputFile, { flags: \"w\", encoding: \"utf8\" });\n\n  const writeLine = (obj: any) =>\n    new Promise<void>((resolve, reject) => {\n      const str = JSON.stringify(obj) + \"\\n\";\n      out.write(str, (err) => (err ? 
reject(err) : resolve()));\n    });\n\n  // Accumulators for post-run analysis\n  let skipped = 0; // invalid JSON lines\n  const caseResults: Array<{\n    run_id: string;\n    success: boolean;\n    provider: string;\n    test_case: number;\n    tool_name: string;\n    input: string;\n    result: RunCaseSummary;\n  }> = [];\n\n  async function processIndex(\n    i: number,\n    k: number,\n    task: ListrTaskWrapper<any, any, any>\n  ) {\n    const line = selectedLines[i];\n    let caseObj: any;\n    try {\n      caseObj = JSON.parse(line);\n    } catch (err: any) {\n      console.error(\n        `Skipping invalid JSON on line ${i + 1}: ${err?.message ?? err}`\n      );\n      skipped++;\n      return;\n    }\n\n    try {\n      const summaries = await runCase(provider, caseObj, {\n        maxTurns,\n        streaming,\n        strict: args.values.strict,\n      });\n\n      for (const summary of summaries) {\n        const record = {\n          run_id: `${i}_${k}`,\n          success: summary.success,\n          provider,\n          test_case: i,\n          tool_name: caseObj.tool_name,\n          input: caseObj.input,\n          result: summary,\n        };\n        task.output = `Case ${i} (attempt ${k + 1}): ${\n          summary.success ? \"Success\" : \"Failed\"\n        } ${summary.toolCallingDetails.warning || \"\"}`;\n        caseResults.push(record);\n        await writeLine(record);\n      }\n    } catch (err: any) {\n      const record = {\n        provider,\n        test_case: i,\n        tool_name: caseObj?.tool_name,\n        input: caseObj?.input,\n        expected_arguments: caseObj?.expected_arguments,\n        instructions: caseObj?.instructions,\n        error: String(err?.message ?? err),\n      };\n      await writeLine(record);\n      task.output = `Case ${i} failed: ${err?.message ?? 
err}`;\n    }\n  }\n\n  const listr = new Listr<{\n    output: string;\n  }>(\n    selectedLines.flatMap((line, index) => {\n      return Array.from({ length: tries }, (_, attempt) => ({\n        title: `Processing case ${index} (attempt ${attempt + 1})`,\n        task: async (_, task) => {\n          await processIndex(index, attempt, task);\n        },\n        rendererOptions: { persistentOutput: true },\n      }));\n    }),\n    {\n      concurrent: 5,\n    }\n  );\n\n  await listr.run();\n\n  await new Promise((resolve) => out.end(resolve));\n  console.log(`Results written to ${outputFile}`);\n  const stats = analyze(caseResults, tries);\n  await writeFile(analysisFile, JSON.stringify(stats, null, 2), \"utf8\");\n  printAnalysis(\n    stats,\n    caseResults,\n    provider,\n    selectedLines,\n    tries,\n    skipped,\n    analysisFile\n  );\n}\n\nmain().catch((err) => {\n  console.error(err);\n  process.exitCode = 1;\n});\n"
  },
  {
    "path": "compatibility-test/package.json",
    "content": "{\n  \"type\": \"module\",\n  \"dependencies\": {\n    \"@openai/agents\": \"^0.0.15\",\n    \"ajv\": \"^8.17.1\",\n    \"listr2\": \"^9.0.1\"\n  },\n  \"scripts\": {\n    \"start\": \"tsx index.ts\"\n  }\n}\n"
  },
  {
    "path": "compatibility-test/providers.ts",
    "content": "export const PROVIDERS = {\n  vllm: {\n    apiBaseUrl: \"http://localhost:8000/v1\",\n    apiKey: \"vllm\",\n    apiType: [\"responses\", \"chat\"], // choose from responses, chat, or both\n    modelName: \"openai/gpt-oss-120b\",\n    providerDetails: {\n      // add any provider-specific details here. These will be passed as part of every request\n      // for example to fix the provider for openrouter, you can do:\n      // provider: {\n      //   only: [\"example\"],\n      // },\n    },\n  },\n};\n"
  },
  {
    "path": "compatibility-test/runCase.ts",
    "content": "import {\n  Agent,\n  Runner,\n  OpenAIResponsesModel,\n  OpenAIChatCompletionsModel,\n  RunResult,\n  StreamedRunResult,\n  FunctionTool,\n  setTracingDisabled,\n} from \"@openai/agents\";\nimport { Ajv } from \"ajv\";\nimport { OpenAI } from \"openai\";\nimport { PROVIDERS } from \"./providers\";\nimport { TOOLS_MAP } from \"./tools\";\n\nsetTracingDisabled(true);\n\nconst ajv = new Ajv();\n\nexport type Case = {\n  tool_name: string;\n  input: string;\n  expected_arguments: string;\n  instructions?: string;\n};\n\n// Summary shape for each apiType\nexport type RunCaseSummary = {\n  apiType: string;\n  success: boolean;\n  validResponse: boolean;\n  validEvents?: boolean;\n  details: Record<string, any>;\n  history: any[];\n  successToolCall: boolean;\n  toolCallingDetails: Record<string, any>;\n};\n\nexport async function runCase(\n  provider: string,\n  caseData: Case,\n  {\n    maxTurns,\n    streaming,\n    strict,\n  }: { maxTurns: number; streaming: boolean; strict: boolean }\n): Promise<RunCaseSummary[]> {\n  const config = PROVIDERS[provider];\n  if (!config) {\n    throw new Error(\n      `Provider ${provider} not found. Valid providers are: ${Object.keys(\n        PROVIDERS\n      ).join(\", \")}`\n    );\n  }\n\n  const agent = new Agent({\n    name: caseData.tool_name,\n    instructions: caseData.instructions,\n    tools: [TOOLS_MAP[caseData.tool_name]],\n  });\n\n  const client = new OpenAI({\n    apiKey: config.apiKey,\n    baseURL: config.apiBaseUrl,\n  });\n\n  const summaries: RunCaseSummary[] = [];\n\n  for (const apiType of config.apiType) {\n    const runner = new Runner({\n      model:\n        apiType === \"responses\"\n          ? new OpenAIResponsesModel(client, config.modelName)\n          : new OpenAIChatCompletionsModel(client, config.modelName),\n      modelSettings: {\n        providerData: config.providerDetails ?? 
{},\n      },\n    });\n\n    let result: RunResult<any, any> | StreamedRunResult<any, any>;\n    let streamedEvents: any[] | undefined = undefined;\n    if (streaming) {\n      result = await runner.run(agent, caseData.input, {\n        stream: streaming,\n        maxTurns: maxTurns,\n      });\n      if (result instanceof StreamedRunResult) {\n        // Collect streaming events if applicable\n        streamedEvents = [];\n        for await (const event of result) {\n          if (event.type === \"raw_model_stream_event\") {\n            if (event.data.type === \"model\") {\n              streamedEvents.push(event.data.event);\n            }\n          }\n        }\n        await result.completed;\n      }\n    } else {\n      result = await runner.run(agent, caseData.input, {\n        maxTurns: maxTurns,\n      });\n    }\n\n    const { success: successToolCall, details: toolCallingDetails } =\n      testToolCall(apiType, caseData, result, strict);\n\n    const { validResponse, details } = testOutputData(\n      apiType,\n      result.rawResponses,\n      streaming\n    );\n\n    const { validEvents, details: eventsDetails } = streaming\n      ? testEvents(apiType, streamedEvents)\n      : { validEvents: true, details: {} };\n\n    let success = successToolCall && validResponse;\n    if (streaming) {\n      success = success && validEvents;\n    }\n    const summary: RunCaseSummary = {\n      apiType,\n      success,\n      validResponse,\n      validEvents,\n      details: {\n        ...details,\n        ...eventsDetails,\n      },\n      history: result?.rawResponses.map((entry) => entry.providerData) ?? 
[],\n      successToolCall,\n      toolCallingDetails,\n    };\n\n    summaries.push(summary);\n  }\n\n  return summaries;\n}\n\nfunction testToolCall(apiType, caseData, result, strict) {\n  let details: Record<string, boolean | string> = {};\n  result.newItems.forEach((item) => {\n    // for this test for now we only care if the tool is called at least once\n    if (details.calledToolAtLeastOnce) {\n      return;\n    }\n\n    const isToolCall = item.type === \"tool_call_item\";\n    if (isToolCall) {\n      if (item.rawItem.type === \"function_call\") {\n        if (item.rawItem.name === caseData.tool_name) {\n          const validate = ajv.compile(\n            (TOOLS_MAP[caseData.tool_name] as FunctionTool).parameters\n          );\n          const valid = validate(JSON.parse(item.rawItem.arguments));\n          details.calledToolWithRightSchema = valid;\n          details.calledToolAtLeastOnce = true;\n\n          if (details.calledToolWithRightSchema) {\n            const parsedArguments = JSON.parse(item.rawItem.arguments);\n            const expectedArguments = JSON.parse(caseData.expected_arguments);\n            details.calledToolWithRightArguments = deepEqual(\n              parsedArguments,\n              expectedArguments\n            );\n            if (!details.calledToolWithRightArguments) {\n              if (details.calledToolWithRightSchema) {\n                details.warning = `Tool call with wrong arguments but correct schema. Check logs for full details. Not failing this test. 
Parsed: ${JSON.stringify(\n                  parsedArguments\n                )} Expected: ${JSON.stringify(expectedArguments)}`;\n              }\n              details.actualArguments = parsedArguments;\n              details.expectedArguments = expectedArguments;\n            }\n          }\n        }\n      }\n    }\n  });\n\n  return {\n    success:\n      !!details.calledToolAtLeastOnce &&\n      !!details.calledToolWithRightSchema &&\n      (!strict || !!details.calledToolWithRightArguments),\n    details,\n  };\n}\n\nfunction testEvents(apiType, events) {\n  // In an ideal world we would replay all the events, reconstruct the final response\n  // and compare it against the final response in the response.completed event.\n  // For now we just check that certain events are present.\n\n  let details: Record<string, boolean> = {};\n  let validEvents: boolean = false;\n\n  if (apiType === \"chat\") {\n    let hasReasoningDeltas = false;\n    for (const event of events) {\n      // some chunks (e.g. a usage-only chunk) may carry no choices, so guard every access\n      const reasoning = event.choices?.[0]?.delta?.reasoning;\n      hasReasoningDeltas =\n        hasReasoningDeltas ||\n        (typeof reasoning === \"string\" && reasoning.length > 0);\n    }\n    details.hasReasoningDeltas = hasReasoningDeltas;\n    validEvents = hasReasoningDeltas;\n  }\n\n  if (apiType === \"responses\") {\n    let hasReasoningDeltaEvents = false;\n    let hasReasoningDoneEvents = false;\n    for (const event of events) {\n      // runCase already unwrapped raw_model_stream_event / \"model\" events,\n      // so `events` holds the provider's Responses API events directly\n      if (event.type === \"response.reasoning_text.delta\") {\n        hasReasoningDeltaEvents = true;\n      }\n      if (event.type === \"response.reasoning_text.done\") {\n        hasReasoningDoneEvents = true;\n      }\n    }\n\n    details.hasReasoningDeltaEvents = hasReasoningDeltaEvents;\n    details.hasReasoningDoneEvents = hasReasoningDoneEvents;\n    validEvents =\n    
  details.hasReasoningDeltaEvents && details.hasReasoningDoneEvents;\n  }\n\n  return {\n    validEvents,\n    details,\n  };\n}\n\nfunction testOutputData(apiType, rawResponses, streaming) {\n  let details: Record<string, boolean> = {};\n  let validResponse: boolean = false;\n\n  if (apiType === \"chat\") {\n    for (const response of rawResponses) {\n      if (streaming && !response.providerData) {\n        // with Chat Completions we don't have a final response object that's native so we skip this test\n        return {\n          validResponse: true,\n          details: {\n            skippedBecauseStreaming: true,\n          },\n        };\n      }\n\n      // this is the actual HTTP response from the provider\n      // Since it's not guaranteed that every response has a reasoning field, we check if it's present\n      // at least once across all responses\n      const data = response.providerData;\n      const message = data.choices[0].message;\n      if (message.role === \"assistant\" && !message.refusal) {\n        details.hasReasoningField =\n          details.hasReasoningField ||\n          (\"reasoning\" in message && typeof message.reasoning === \"string\");\n        details.hasReasoningContentField =\n          details.hasReasoningContentField ||\n          (\"reasoning_content\" in message &&\n            typeof message.reasoning_content === \"string\");\n\n        validResponse =\n          validResponse ||\n          (details.hasReasoningField && message.reasoning.length > 0);\n      }\n    }\n  } else if (apiType === \"responses\") {\n    // this is the actual HTTP response from the provider\n    const data = rawResponses[0].providerData;\n    for (const item of data.output) {\n      // Since it's not guaranteed that every response has a reasoning field, we check if it's present\n      // at least once across all responses\n\n      if (item.type === \"reasoning\") {\n        details.hasReasoningContentArray = Array.isArray(item.content);\n        
details.hasReasoningContentArrayLength = item.content.length > 0;\n        details.hasReasoningContentArrayItemType = item.content.every(\n          (item) => item.type === \"reasoning_text\"\n        );\n        details.hasReasoningContentArrayItemText = item.content.every(\n          (item) => item.text.length > 0\n        );\n\n        validResponse =\n          details.hasReasoningContentArray &&\n          details.hasReasoningContentArrayLength &&\n          details.hasReasoningContentArrayItemType &&\n          details.hasReasoningContentArrayItemText;\n      }\n    }\n  }\n\n  return {\n    validResponse,\n    details,\n  };\n}\n\nfunction deepEqual(a: any, b: any): boolean {\n  if (a === b) return true;\n  if (typeof a !== typeof b) return false;\n  if (a && b && typeof a === \"object\") {\n    if (Array.isArray(a) !== Array.isArray(b)) return false;\n    if (Array.isArray(a)) {\n      if (a.length !== b.length) return false;\n      for (let i = 0; i < a.length; i++) {\n        if (!deepEqual(a[i], b[i])) return false;\n      }\n      return true;\n    } else {\n      const aKeys = Object.keys(a);\n      const bKeys = Object.keys(b);\n      if (aKeys.length !== bKeys.length) return false;\n      for (const key of aKeys) {\n        if (!b.hasOwnProperty(key)) return false;\n        if (!deepEqual(a[key], b[key])) return false;\n      }\n      return true;\n    }\n  }\n  return false;\n}\n"
  },
  {
    "path": "compatibility-test/tools.ts",
    "content": "import { Tool, tool } from \"@openai/agents\";\n\nfunction convertToTool(toolData: any) {\n  return tool({\n    name: toolData.name,\n    description: toolData.description,\n    parameters: toolData.parameters,\n    execute: async (parameters) => {\n      return toolData.output;\n    },\n    strict: false,\n  });\n}\n\nexport const TOOLS = [\n  {\n    type: \"function\",\n    name: \"get_weather\",\n    description: \"Get the weather for a given location\",\n    parameters: {\n      type: \"object\",\n      properties: {\n        location: {\n          type: \"string\",\n          description: \"The location to get the weather for\",\n        },\n      },\n      required: [\"location\"],\n      additionalProperties: false,\n    },\n    output: '{\"weather\":\"sunny\"}',\n  },\n  {\n    type: \"function\",\n    name: \"get_system_health\",\n    description:\n      \"Returns the current health status of the LLM runtime—use before critical operations to verify the service is live.\",\n    parameters: { type: \"object\", properties: {} },\n    output: '{\"status\":\"ok\",\"uptime_seconds\":372045}',\n  },\n  {\n    type: \"function\",\n    name: \"markdown_to_html\",\n    description:\n      \"Converts a Markdown string to sanitized HTML—use when you need browser-renderable output.\",\n    parameters: {\n      type: \"object\",\n      properties: {\n        markdown: { type: \"string\", description: \"Raw Markdown content\" },\n      },\n      required: [\"markdown\"],\n      additionalProperties: false,\n    },\n    output: '{\"html\":\"<h1>Hello World</h1><p>This is <em>great</em>.</p>\"}',\n  },\n  {\n    type: \"function\",\n    name: \"detect_language\",\n    description:\n      \"Identifies the ISO language code of the supplied text—use for routing text to language-specific models.\",\n    parameters: {\n      type: \"object\",\n      properties: {\n        text: {\n          type: \"string\",\n          description: \"Text whose language should 
be detected\",\n        },\n      },\n      required: [\"text\"],\n      additionalProperties: false,\n    },\n    output: '{\"language\":\"de\",\"confidence\":0.98}',\n  },\n  {\n    type: \"function\",\n    name: \"generate_chart\",\n    description:\n      \"Creates a base64-encoded PNG chart from tabular data—use for quick visualizations inside chat.\",\n    parameters: {\n      type: \"object\",\n      properties: {\n        data: {\n          type: \"array\",\n          items: { type: \"array\", items: { type: \"number\" } },\n          description: \"2-D numeric data matrix\",\n        },\n        chart_type: {\n          type: \"string\",\n          enum: [\"line\", \"bar\", \"scatter\"],\n          description: \"Type of chart to generate\",\n        },\n        title: {\n          type: \"string\",\n          description: \"Chart title\",\n          default: \"\",\n        },\n        x_label: {\n          type: \"string\",\n          description: \"Label for the x-axis\",\n          default: \"\",\n        },\n        y_label: {\n          type: \"string\",\n          description: \"Label for the y-axis\",\n          default: \"\",\n        },\n      },\n      required: [\"data\", \"chart_type\"],\n      additionalProperties: false,\n    },\n    output: '{\"image_png_base64\":\"iVBORw0KGgoAAAANSUhEUgAA...\"}',\n  },\n  {\n    type: \"function\",\n    name: \"query_database\",\n    description:\n      \"Runs a parameterized SQL SELECT on the internal analytics DB—use for lightweight data look-ups.\",\n    parameters: {\n      type: \"object\",\n      properties: {\n        table: { type: \"string\", description: \"Table name to query\" },\n        columns: {\n          type: \"array\",\n          items: { type: \"string\" },\n          description: \"Columns to return\",\n        },\n        filters: {\n          type: \"string\",\n          description: \"SQL WHERE clause without the word WHERE\",\n          default: \"\",\n        },\n        limit: {\n 
         type: \"integer\",\n          minimum: 1,\n          maximum: 10000,\n          description: \"Max rows to return\",\n          default: 100,\n        },\n        order_by: {\n          type: \"string\",\n          description: \"Column to order by (optional)\",\n          default: \"\",\n        },\n      },\n      required: [\"table\", \"columns\"],\n      additionalProperties: false,\n    },\n    output:\n      '{\"rows\":[{\"id\":1,\"email\":\"user@example.com\"},{\"id\":2,\"email\":\"foo@bar.com\"}],\"row_count\":2}',\n  },\n];\n\nexport const TOOLS_MAP = TOOLS.reduce((acc, tool) => {\n  acc[tool.name] = convertToTool(tool);\n  return acc;\n}, {} as Record<string, Tool>);\n"
  },
  {
    "path": "examples/agents-sdk-js/index.ts",
    "content": "import { OpenAI } from \"openai\";\nimport {\n  Agent,\n  run,\n  setDefaultOpenAIClient,\n  setOpenAIAPI,\n  setTracingDisabled,\n  tool,\n  MCPServerStdio,\n} from \"@openai/agents\";\nimport { z } from \"zod\";\nimport path from \"node:path\";\nimport process from \"node:process\";\nimport { styleText } from \"node:util\";\nimport { createInterface } from \"node:readline/promises\";\n\nasync function prompt(question: string) {\n  const rl = createInterface({\n    input: process.stdin,\n    output: process.stdout,\n  });\n  const answer = await rl.question(question);\n  rl.close();\n  return answer;\n}\n\nconst openai = new OpenAI({\n  apiKey: \"local\",\n  baseURL: \"http://localhost:11434/v1\",\n});\n\nconst samplesDir = path.join(process.cwd());\n\nconst mcpServer = new MCPServerStdio({\n  name: \"Filesystem MCP Server, via npx\",\n  fullCommand: `npx -y @modelcontextprotocol/server-filesystem ${samplesDir}`,\n});\n\nawait mcpServer.connect();\n\nsetTracingDisabled(true);\nsetDefaultOpenAIClient(openai);\nsetOpenAIAPI(\"chat_completions\");\n\nconst searchTool = tool({\n  name: \"get_current_weather\",\n  description: \"Get the current weather in a given location\",\n  parameters: z.object({\n    location: z.string(),\n  }),\n  execute: async ({ location }) => {\n    return `The weather in ${location} is sunny.`;\n  },\n});\n\nconst agent = new Agent({\n  name: \"My Agent\",\n  instructions: \"You are a helpful assistant.\",\n  tools: [searchTool],\n  model: \"gpt-oss:20b-test\",\n  mcpServers: [mcpServer],\n});\n\nconst input = await prompt(\"> \");\n\nconst result = await run(agent, input, {\n  stream: true,\n});\n\nfor await (const event of result) {\n  if (event.type === \"raw_model_stream_event\" && event.data.type === \"model\") {\n    if (event.data.event.choices[0].delta.content) {\n      process.stdout.write(event.data.event.choices[0].delta.content);\n    } else if (event.data.event.choices[0].delta.reasoning) {\n      
process.stdout.write(event.data.event.choices[0].delta.reasoning);\n    }\n  } else if (\n    event.type === \"run_item_stream_event\" &&\n    event.item.type === \"tool_call_item\" &&\n    event.item.rawItem.type === \"function_call\"\n  ) {\n    console.log(\n      `\\nCalling ${event.item.rawItem.name} with: ${event.item.rawItem.arguments}`\n    );\n  }\n}\n\nconsole.log(\"\\n\");\nawait result.completed;\nawait mcpServer.close();\n"
  },
  {
    "path": "examples/agents-sdk-js/package.json",
    "content": "{\n  \"type\": \"module\",\n  \"name\": \"agents-sdk\",\n  \"version\": \"1.0.0\",\n  \"main\": \"index.js\",\n  \"scripts\": {\n    \"start\": \"tsx index.ts\",\n    \"test\": \"echo \\\"Error: no test specified\\\" && exit 1\"\n  },\n  \"keywords\": [],\n  \"author\": \"\",\n  \"license\": \"ISC\",\n  \"description\": \"\",\n  \"dependencies\": {\n    \"@openai/agents\": \"^0.0.14\",\n    \"tsx\": \"^4.20.3\",\n    \"typescript\": \"^5.8.3\",\n    \"zod\": \"^3.25.67\"\n  }\n}\n"
  },
  {
    "path": "examples/agents-sdk-python/example.py",
    "content": "import asyncio\nfrom pathlib import Path\nimport shutil\n\nfrom openai import AsyncOpenAI\nfrom agents import (\n    Agent,\n    ItemHelpers,\n    Runner,\n    set_default_openai_api,\n    set_default_openai_client,\n    set_tracing_disabled,\n    function_tool,\n)\nfrom agents.mcp import MCPServerStdio\n\n\nasync def prompt_user(question: str) -> str:\n    \"\"\"Async input prompt function\"\"\"\n    loop = asyncio.get_event_loop()\n    return await loop.run_in_executor(None, input, question)\n\n\nasync def main():\n    # Set up OpenAI client for local server (e.g., Ollama)\n    openai_client = AsyncOpenAI(\n        api_key=\"local\",\n        base_url=\"http://localhost:11434/v1\",\n    )\n\n    # Get current working directory\n    samples_dir = str(Path.cwd())\n\n    # Create MCP server for filesystem operations\n    mcp_server = MCPServerStdio(\n        name=\"Filesystem MCP Server, via npx\",\n        params={\n            \"command\": \"npx\",\n            \"args\": [\n                \"-y\",\n                \"@modelcontextprotocol/server-filesystem\",\n                samples_dir,\n            ],\n        },\n    )\n\n    # Connect to MCP server\n    await mcp_server.connect()\n\n    # Configure agents SDK\n    set_tracing_disabled(True)\n    set_default_openai_client(openai_client)\n    set_default_openai_api(\"chat_completions\")\n\n    # Define weather tool\n    @function_tool\n    async def get_weather(location: str) -> str:\n        return f\"The weather in {location} is sunny.\"\n\n    # Create agent\n    agent = Agent(\n        name=\"My Agent\",\n        instructions=\"You are a helpful assistant.\",\n        tools=[get_weather],\n        model=\"gpt-oss:20b-test\",\n        mcp_servers=[mcp_server],\n    )\n\n    # Get user input\n    user_input = await prompt_user(\"> \")\n\n    # Run agent with streaming\n    result = Runner.run_streamed(agent, user_input)\n\n    # Process streaming results\n    async for event in 
result.stream_events():\n        if event.type == \"raw_response_event\":\n            continue\n        elif event.type == \"agent_updated_stream_event\":\n            print(f\"Agent updated: {event.new_agent.name}\")\n        elif event.type == \"run_item_stream_event\":\n            if event.item.type == \"tool_call_item\":\n                print(\"-- Tool was called\")\n            elif event.item.type == \"tool_call_output_item\":\n                print(f\"-- Tool output: {event.item.output}\")\n            elif event.item.type == \"message_output_item\":\n                print(\n                    f\"-- Message output:\\n {ItemHelpers.text_message_output(event.item)}\"\n                )\n            else:\n                pass\n\n    print(\"=== Run complete ===\")\n\n    # Shut down the MCP server subprocess started by connect()\n    await mcp_server.cleanup()\n\n\nif __name__ == \"__main__\":\n\n    if not shutil.which(\"npx\"):\n        raise RuntimeError(\n            \"npx is not installed. It ships with npm, so install Node.js (which includes npm) to get it.\"\n        )\n    asyncio.run(main())\n"
  },
  {
    "path": "examples/agents-sdk-python/pyproject.toml",
    "content": "[project]\nname = \"agents-sdk-python\"\nversion = \"0.1.0\"\ndescription = \"Add your description here\"\nreadme = \"README.md\"\nrequires-python = \">=3.12\"\ndependencies = [\n    \"openai-agents>=0.2.4\",\n]\n"
  },
  {
    "path": "examples/gradio/gradio_chat.py",
    "content": "import json\nimport requests\nimport gradio as gr\n\nDEFAULT_FUNCTION_PROPERTIES = \"\"\"\n{\n    \"type\": \"object\",\n    \"properties\": {\n        \"location\": {\n            \"type\": \"string\",\n            \"description\": \"The city and state, e.g. San Francisco, CA\"\n        }\n    },\n    \"required\": [\"location\"]\n}\n\"\"\".strip()\n\ndef chat_with_model(message, history, model_choice, instructions, effort, use_functions, \n                   function_name, function_description, function_parameters,\n                   use_browser_search, temperature, max_output_tokens, debug_mode):\n    \n    if not message.strip():\n        return history, \"\"\n    \n    # Append user message and empty assistant placeholder (idiomatic Gradio pattern)\n    history = history + [[message, \"\"]]\n    \n    # Build messages list from history (excluding the empty assistant placeholder)\n    messages = []\n    \n    # Convert history to messages format (excluding the last empty assistant message)\n    for user_msg, assistant_msg in history[:-1]:\n        if user_msg:\n            messages.append({\n                \"type\": \"message\",\n                \"role\": \"user\", \n                \"content\": [{\"type\": \"input_text\", \"text\": user_msg}]\n            })\n        if assistant_msg:\n            messages.append({\n                \"type\": \"message\",\n                \"role\": \"assistant\",\n                \"content\": [{\"type\": \"output_text\", \"text\": assistant_msg}]\n            })\n    \n    # Add current user message\n    messages.append({\n        \"type\": \"message\",\n        \"role\": \"user\",\n        \"content\": [{\"type\": \"input_text\", \"text\": message}]\n    })\n    \n    # Prepare tools\n    tools = []\n    if use_functions:\n        try:\n            tools.append({\n                \"type\": \"function\",\n                \"name\": function_name,\n                \"description\": function_description,\n         
       \"parameters\": json.loads(function_parameters),\n            })\n        except json.JSONDecodeError:\n            pass\n    \n    if use_browser_search:\n        tools.append({\"type\": \"browser_search\"})\n    \n    # Get URL based on model (matching streamlit logic)\n    options = [\"large\", \"small\"]\n    URL = (\"http://localhost:8081/v1/responses\" if model_choice == options[1] \n           else \"http://localhost:8000/v1/responses\")\n    \n    try:\n        response = requests.post(\n            URL,\n            json={\n                \"input\": messages,\n                \"stream\": True,\n                \"instructions\": instructions,\n                \"reasoning\": {\"effort\": effort},\n                \"metadata\": {\"__debug\": debug_mode},\n                \"tools\": tools,\n                \"temperature\": temperature,\n                \"max_output_tokens\": max_output_tokens,\n            },\n            stream=True,\n        )\n        \n        full_content = \"\"\n        text_delta = \"\"\n        current_output_index = 0\n        in_reasoning = False\n        \n        for line in response.iter_lines(decode_unicode=True):\n            if not line or not line.startswith(\"data:\"):\n                continue\n            data_str = line[len(\"data:\"):].strip()\n            if not data_str:\n                continue\n            \n            try:\n                data = json.loads(data_str)\n            except Exception:\n                continue\n            \n            event_type = data.get(\"type\", \"\")\n            output_index = data.get(\"output_index\", 0)\n            \n            if event_type == \"response.output_item.added\":\n                current_output_index = output_index\n                output_type = data.get(\"item\", {}).get(\"type\", \"message\")\n                text_delta = \"\"\n                \n                if output_type == \"reasoning\":\n                    if not in_reasoning:\n               
         full_content += \"🤔 **Thinking...**\\n\"\n                        in_reasoning = True\n                elif output_type == \"message\":\n                    if in_reasoning:\n                        full_content += \"\\n\\n\"\n                        in_reasoning = False\n                \n            elif event_type == \"response.reasoning_text.delta\":\n                delta = data.get(\"delta\", \"\")\n                full_content += delta\n                \n                # Update last assistant message (idiomatic Gradio pattern)\n                history[-1][1] = full_content\n                yield history, \"\"\n                \n            elif event_type == \"response.output_text.delta\":\n                delta = data.get(\"delta\", \"\")\n                full_content += delta\n                \n                # Update last assistant message (idiomatic Gradio pattern)  \n                history[-1][1] = full_content\n                yield history, \"\"\n                \n            elif event_type == \"response.output_item.done\":\n                item = data.get(\"item\", {})\n                if item.get(\"type\") == \"function_call\":\n                    function_call_text = f\"\\n\\n🔨 Called `{item.get('name')}`\\n**Arguments**\\n```json\\n{item.get('arguments', '')}\\n```\"\n                    full_content += function_call_text\n                    \n                    # Update last assistant message (idiomatic Gradio pattern)\n                    history[-1][1] = full_content\n                    yield history, \"\"\n                    \n                elif item.get(\"type\") == \"web_search_call\":\n                    web_search_text = f\"\\n\\n🌐 **Web Search**\\n```json\\n{json.dumps(item.get('action', {}), indent=2)}\\n```\\n✅ Done\"\n                    full_content += web_search_text\n                    \n                    # Update last assistant message (idiomatic Gradio pattern)\n                    history[-1][1] = 
full_content\n                    yield history, \"\"\n                    \n            elif event_type == \"response.completed\":\n                response_data = data.get(\"response\", {})\n                if debug_mode:\n                    debug_info = response_data.get(\"metadata\", {}).get(\"__debug\", \"\")\n                    if debug_info:\n                        full_content += f\"\\n\\n**Debug**\\n```\\n{debug_info}\\n```\"\n                        \n                        # Update last assistant message (idiomatic Gradio pattern)\n                        history[-1][1] = full_content\n                        yield history, \"\"\n                break\n        \n        # Gradio only displays values a generator yields (its return value is ignored),\n        # so yield the final history and an empty string to clear the textbox\n        yield history, \"\"\n        \n    except Exception as e:\n        error_message = f\"❌ Error: {str(e)}\"\n        history[-1][1] = error_message\n        # Yield (not return) so the error is actually shown in the chat window\n        yield history, \"\"\n\n\n# Create the Gradio interface\nwith gr.Blocks(title=\"💬 Chatbot\") as demo:\n    gr.Markdown(\"# 💬 Chatbot\")\n    \n    with gr.Row():\n        with gr.Column(scale=3):\n            chatbot = gr.Chatbot(height=500)\n            \n            with gr.Row():\n                msg = gr.Textbox(placeholder=\"Type a message...\", scale=4, show_label=False)\n                send_btn = gr.Button(\"Send\", scale=1)\n            \n            clear_btn = gr.Button(\"Clear Chat\")\n        \n        with gr.Column(scale=1):\n            model_choice = gr.Radio([\"large\", \"small\"], value=\"small\", label=\"Model\")\n            \n            instructions = gr.Textbox(\n                label=\"Instructions\", \n                value=\"You are a helpful assistant that can answer questions and help with tasks.\",\n                lines=3\n            )\n            \n            effort = gr.Radio([\"low\", \"medium\", \"high\"], value=\"medium\", label=\"Reasoning effort\")\n            \n            gr.Markdown(\"#### Functions\")\n            
use_functions = gr.Checkbox(label=\"Use functions\", value=False)\n            \n            with gr.Column(visible=False) as function_group:\n                function_name = gr.Textbox(label=\"Function name\", value=\"get_weather\")\n                function_description = gr.Textbox(\n                    label=\"Function description\", \n                    value=\"Get the weather for a given city\"\n                )\n                function_parameters = gr.Textbox(\n                    label=\"Function parameters\", \n                    value=DEFAULT_FUNCTION_PROPERTIES,\n                    lines=6\n                )\n            \n            # Conditional browser search (matching Streamlit logic)\n            # In Streamlit: if \"show_browser\" in st.query_params:\n            # For Gradio, we'll always show it (simplified)\n            gr.Markdown(\"#### Built-in Tools\")\n            use_browser_search = gr.Checkbox(label=\"Use browser search\", value=False)\n            \n            temperature = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label=\"Temperature\")\n            max_output_tokens = gr.Slider(1000, 20000, value=1024, step=100, label=\"Max output tokens\")\n            \n            debug_mode = gr.Checkbox(label=\"Debug mode\", value=False)\n    \n    # Event handlers\n    def toggle_function_group(use_funcs):\n        return gr.update(visible=use_funcs)\n    \n    use_functions.change(toggle_function_group, use_functions, function_group)\n    \n    # Chat functionality\n    inputs = [msg, chatbot, model_choice, instructions, effort, use_functions, \n              function_name, function_description, function_parameters,\n              use_browser_search, temperature, max_output_tokens, debug_mode]\n    \n    msg.submit(chat_with_model, inputs, [chatbot, msg])\n    send_btn.click(chat_with_model, inputs, [chatbot, msg])\n    clear_btn.click(lambda: [], outputs=chatbot)\n\n\nif __name__ == \"__main__\":\n    demo.launch()"
  },
  {
    "path": "examples/reinforcement-fine-tuning.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"view-in-github\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"<a href=\\\"https://colab.research.google.com/github/openai/gpt-oss/blob/main/examples/reinforcement-fine-tuning.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Free Colab\\\"/></a>\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"hzPgFeIkZn9q\"\n      },\n      \"source\": [\n        \"# Make gpt-oss play games with Reinforcement Learning\\n\",\n        \"\\n\",\n        \"This notebook demonstrates how to make `gpt-oss` play the 2048 game autonomously using reinforcement learning (RL).\\n\",\n        \"\\n\",\n        \"We will train `gpt-oss-20b` using [Unsloth](https://github.com/unslothai/unsloth) to develop a strategy for playing 2048. The strategy will run until the game ends, and the model will be rewarded or penalized based on whether it wins or loses.\\n\",\n        \"\\n\",\n        \"<img src=\\\"https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/2048_win.png/500px-2048_win.png\\\" width=300 />\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"31KIMLJLnHET\"\n      },\n      \"source\": [\n        \"# Installation\\n\",\n        \"To run `gpt-oss-20b` RL on a free Google Colab instance, we’ll use the GRPO algorithm along with [Unsloth](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning), an open-source tool that reduces VRAM usage and speeds up training.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"CGoDZwcunHEU\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"%%capture\\n\",\n        \"!pip install --upgrade -qqq uv\\n\",\n        \"try: 
import numpy; get_numpy = f\\\"numpy=={numpy.__version__}\\\"\\n\",\n        \"except: get_numpy = \\\"numpy\\\"\\n\",\n        \"!uv pip install -qqq \\\\\\n\",\n        \"    \\\"torch>=2.8.0\\\" \\\"triton>=3.4.0\\\" {get_numpy} torchvision bitsandbytes \\\"transformers==4.56.2\\\" \\\\\\n\",\n        \"    \\\"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo\\\" \\\\\\n\",\n        \"    \\\"unsloth[base] @ git+https://github.com/unslothai/unsloth\\\" \\\\\\n\",\n        \"    git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels\\n\",\n        \"!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers\\n\",\n        \"!uv pip install --no-deps trl==0.22.2\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"We'll load gpt-oss-20b and set some parameters:\\n\",\n        \"* `max_seq_length = 768` The maximum context length of the model. Increasing it uses more memory; 768 was the largest value we found to fit on a free 15GB Tesla T4 machine\\n\",\n        \"* `lora_rank = 4` A larger rank makes the RL process smarter, but slower and more memory-hungry\\n\",\n        \"* `load_in_4bit = True` Uses quantization to reduce memory usage by 75% with only a small loss in accuracy. 
`load_in_16bit` will be faster but needs an 80GB GPU (H100, B200)\\n\",\n        \"* `offload_embedding = True` An Unsloth optimization that moves the embeddings to CPU RAM, reducing VRAM usage by about 1GB\"\n      ],\n      \"metadata\": {\n        \"id\": \"CcLYwLyQLADE\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\",\n          \"height\": 575,\n          \"referenced_widgets\": [\n            \"abe2b0a2913d4633943f44333ae799f8\",\n            \"2c40c6b846924200b29616a590af1672\",\n            \"749e8407a901483c8b513a2fb71596c8\",\n            \"7baca79d720c40b5a923b9717e28c982\",\n            \"68ea891644ca4753a8e1bf278ff47e84\",\n            \"06ab9eaa6f0f48c4b68cff1ca4b9f2fa\",\n            \"d98c2b1e979b4929891a8ee0c11f55df\",\n            \"ef01b874478b4bb497d31d2f8dd6145a\",\n            \"d50ea8cded9848ffa18be1ae6a2559df\",\n            \"ffabf89ecd9d48a5a3fc2a1c855ce080\",\n            \"614c5332c7d045109102a329e7f69dfd\",\n            \"caf742160db041a1b6c2cfdf78f2dc9a\",\n            \"34a9e38b0b454a69a067d1ddadec7626\",\n            \"263b7dc0b3fd465fac89b9266b19d526\",\n            \"5b7af68130f04a63ad3efa3d9f602ebe\",\n            \"2a6aa92676c74509b58373ca604c5b3b\",\n            \"9c4d6839934b4b13952a850d2084d498\",\n            \"c6a1decbc0e7421db622033214913cb9\",\n            \"147743757c804b85af2ef194f5f84e6a\",\n            \"2820e352ab004e818949acc31eb3888d\",\n            \"80fa3aef5e2040d9904c6b87b7214ca0\",\n            \"0f99489932aa409b94ba34764aff19b0\",\n            \"6ab4e5676ad84807a126fffa99f7a0d4\",\n            \"e61ef80398444c13bf7cd20ef21a5057\",\n            \"5ebe7b4e4ed24c53b783ee46377c682d\",\n            \"e0fdef0087bc4a91a11932a2d933c001\",\n            \"596c2a62a635469eb74233ce00586a6f\",\n            \"da4324e287e64e5ba98fc110693066df\",\n            \"8c7c6bb04a3f4a1494b34529f95a195c\",\n  
          \"51aaa109480d4ae6bd419aea689d22ee\",\n            \"acf4e50a248342f68d26daef21baa419\",\n            \"7d3379cbd27a4218a9d84c5a12f3bb88\",\n            \"7841bc90b6a74120ab3e603c76332a01\",\n            \"3f9b801b52da4eb79f730d87bea5c338\",\n            \"b66c6ded549d4db8a2e5ea8e5016615c\",\n            \"43da5073c3ad4e98a3ade17a0bb3b93d\",\n            \"40365e2c9fef49148e4c93592d458afc\",\n            \"7e9d5212fc7844f286e14b70cbf0bc7a\",\n            \"77d34c0f1de548b4872208a063bb5017\",\n            \"bf96e8666c224c26b0a01451d08e907a\",\n            \"4513a73fa95b41b5b6edadc9143ba9c1\",\n            \"792d75a7d18945e7972826ac5b2ac386\",\n            \"2a6f43b64d164636a2d9708f0190f21b\",\n            \"65c62d2198e64ee4a9e6547c2733135a\",\n            \"219ca32ab51e4b4385b2c1026a78503a\",\n            \"6c2ccfe3363b40b58fc26ea164d4ead4\",\n            \"07f0420c4dfa477caccd7ae96551c2e4\",\n            \"1c96edb2f7c948b9968b1239982af942\",\n            \"d93be4994f104b6e99d89a9e73cd6abd\",\n            \"4da21f53bf7f4e2d8132eb43e6ecc739\",\n            \"735f70fac43449e3974de1b783d56d33\",\n            \"ad75f887a140416abfca615b2fc3c385\",\n            \"dee02a37a6f44f168546ee0077dc20d1\",\n            \"ee23056662ad4b719b65005d776e0e72\",\n            \"87765ca0996b403dbe29deef48d548bf\",\n            \"8db5e86577744ff1a39c8e198eee5dd3\",\n            \"4b9b3fe8dc764eedb9e18f166fe2f548\",\n            \"cca95e973bc445d3811335debf7c446e\",\n            \"e507a46b4c754d9a8aede2aac0d203bc\",\n            \"751a46fbb8e24efabfb381a85c90fbe8\",\n            \"87a808c4d4f54f719adcd29de7206e1b\",\n            \"5f0b2a0e1953406b88af2c884904e2da\",\n            \"2fa84865e9f14c1491402ef81517b4bd\",\n            \"245590db7d374515a428ff4abbd25588\",\n            \"e2973e6c02834a7c9f2f6ce5755f35f0\",\n            \"48741bbdeccb459aa4eea9c61339764b\",\n            \"1183d3f2ad3c4fb0af1d925b5f9e3efe\",\n            \"9cc51d8029eb4217bc37daa918649692\",\n            
\"41f13d2f023e405180689e03bc2c32a1\",\n            \"247484c0bf5945bcb4627b48928366c8\",\n            \"14c0f20a9ab341ee966fe77815099ff0\",\n            \"a219f3b89a34443abe612846676f9356\",\n            \"152d7bf2a74f400db3d3ecaa719ef8d1\",\n            \"36676899a61f4be4b631f6271f6ecec9\",\n            \"77ecad9f150c430fa85f5833d97c42df\",\n            \"cef064f1c55f41bf957fc4623260fdb4\",\n            \"37cbe8800af04a42a0355922969b6393\",\n            \"f8dacdab001d4db0b6b3776ac7d3634a\",\n            \"5a59fb5f7acf4213847c985e66c9ee3c\",\n            \"ae6d42fb84fc4984af1d4430acdcd3c9\",\n            \"02d120e49f2c4f95a6090b1d8d521767\",\n            \"8f1e6c36b84c4115a671dcb9ade41c8b\",\n            \"81a728910a2341a785a6f252bbb371f7\",\n            \"69a8d50f11244ba688c183d14d2395ec\",\n            \"350f29f737534bfba4258bc31ec274a2\",\n            \"9beac0680e3049dfafcb6ec185fd2265\",\n            \"dbf5ed93dac646ed979fa7a8c569dfe3\",\n            \"4db5ee5b7b674abba75fbce264e6dfa3\",\n            \"0c0c96eeac664f339aa4511bf47087e2\",\n            \"18451e19df5449b1853b5e13dacd19c5\",\n            \"d864d29d02c54ecfaedd7b866a6df8c2\",\n            \"7875163297284832a35aca84cbb105ce\",\n            \"d42d8228ea1247a1a81bb99b18c4640c\",\n            \"bcda4c9a48e943a6a0ef812fcd64a6db\",\n            \"61e491b843c347b6b2a9948de7caf01d\",\n            \"dee07d33b8de4c3b847fcff670e68102\",\n            \"b07acf871a0a46f1889bfb439d13752b\",\n            \"ba94310dc12a4a258205b14901ad3f94\",\n            \"a93210a691414502ba3c2dff03ffb4ce\",\n            \"fd2fe9ef6da64f72ab29d481d1739f5e\",\n            \"dbfeea8ee2374b8c8fa70431c35f281f\",\n            \"84d27c45065e426badbfcfcdc8ff16b6\",\n            \"fa9ea0d3234e41689c827485d0360885\",\n            \"4cb119127b404f46a53012c62d004e28\",\n            \"d9020a2a2c8440db81d2cfdf0289b667\",\n            \"04d39c4dda9f4a1bb01b8d6320032372\",\n            \"4d67b10ec7794170addb4e968e20f170\",\n            
\"55ac5c2a82ee48fe988e1e4f26c168b0\",\n            \"9a079a30b4ae4bbc80122faf83e0ad59\",\n            \"acda8e7582934fecbbf854e66e23f698\",\n            \"4fbc4cfe529d471ba85f3ae8e53b28d6\",\n            \"a0d0fedc5bec4f5b943fddf9a954fbdf\",\n            \"cab602573c6940919f93e59fe6f4838d\",\n            \"51b8f4ce40f94ac39cf44d98f1522ec7\",\n            \"32d6af64f2464cfb965671f2692b4e15\",\n            \"e1e77d98b01f4376a6c075975c27571e\",\n            \"6a47e60b10a6481b94aee021c8dbc7ba\",\n            \"5657a84bf4b74710b2de1a54f9236e39\",\n            \"7bd5d1beeb0e49e293d9f6b91bb6d7fb\",\n            \"60ceb890b5644493a8886d91b9dac461\",\n            \"40138ff29073407abb95f793509fc320\",\n            \"0ac4d8e674804ad6bdc5f2d62f2e0d33\",\n            \"7bfcd9acf29646db8b6123708d1ffe27\",\n            \"5e88d6515f16475fb72d7c153422b591\",\n            \"5e5b77dd649547f896ab306fccc94a4e\",\n            \"a843fa23e6c94fb486bff8764574fdc5\",\n            \"fd0ac7ed3d3146ec85913f4e05c4a2f6\",\n            \"77204d81ff8f4ee585361a503fa647dc\",\n            \"923653dfe90e475a9efa44baf98ba9a0\",\n            \"62600092f8cc43f493b86b0169f67be1\",\n            \"59e46bbe96df4b88ad31c09096ce0e0a\",\n            \"8f5c7b88a2cc4b5abb0814c814833349\"\n          ]\n        },\n        \"id\": \"DkIvEkIIkEyB\",\n        \"outputId\": \"2f85e1d0-8810-4b41-b683-0c33578d991c\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\\n\",\n            \"🦥 Unsloth Zoo will now patch everything to make training faster!\\n\",\n            \"==((====))==  Unsloth 2025.10.1: Fast Gpt_Oss patching. Transformers: 4.56.2.\\n\",\n            \"   \\\\\\\\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.\\n\",\n            \"O^O/ \\\\_/ \\\\    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. 
Triton: 3.4.0\\n\",\n            \"\\\\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]\\n\",\n            \" \\\"-____-\\\"     Free license: http://github.com/unslothai/unsloth\\n\",\n            \"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\\n\",\n            \"Unsloth: Using float16 precision for gpt_oss won't work! Using float32.\\n\"\n          ]\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"abe2b0a2913d4633943f44333ae799f8\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"model.safetensors.index.json: 0.00B [00:00, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"caf742160db041a1b6c2cfdf78f2dc9a\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"6ab4e5676ad84807a126fffa99f7a0d4\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"model-00001-of-00004.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": 
\"3f9b801b52da4eb79f730d87bea5c338\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"model-00004-of-00004.safetensors:   0%|          | 0.00/1.16G [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"219ca32ab51e4b4385b2c1026a78503a\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"model-00002-of-00004.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"8db5e86577744ff1a39c8e198eee5dd3\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"model-00003-of-00004.safetensors:   0%|          | 0.00/3.37G [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"1183d3f2ad3c4fb0af1d925b5f9e3efe\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": 
\"f8dacdab001d4db0b6b3776ac7d3634a\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"generation_config.json:   0%|          | 0.00/165 [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Unsloth: Offloading embeddings to RAM to save 1.08 GB.\\n\"\n          ]\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"0c0c96eeac664f339aa4511bf47087e2\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"tokenizer_config.json: 0.00B [00:00, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"fd2fe9ef6da64f72ab29d481d1739f5e\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"tokenizer.json:   0%|          | 0.00/27.9M [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"4fbc4cfe529d471ba85f3ae8e53b28d6\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"special_tokens_map.json:   0%|          | 0.00/446 [00:00<?, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n     
     \"data\": {\n            \"application/vnd.jupyter.widget-view+json\": {\n              \"model_id\": \"0ac4d8e674804ad6bdc5f2d62f2e0d33\",\n              \"version_major\": 2,\n              \"version_minor\": 0\n            },\n            \"text/plain\": [\n              \"chat_template.jinja: 0.00B [00:00, ?B/s]\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        }\n      ],\n      \"source\": [\n        \"from unsloth import FastLanguageModel\\n\",\n        \"import torch\\n\",\n        \"max_seq_length = 768 # Can increase for longer RL output\\n\",\n        \"lora_rank = 4        # Larger rank = smarter, but slower\\n\",\n        \"model, tokenizer = FastLanguageModel.from_pretrained(\\n\",\n        \"    model_name = \\\"unsloth/gpt-oss-20b\\\", # unsloth/gpt-oss-20b-BF16 for H100s\\n\",\n        \"    max_seq_length = max_seq_length,\\n\",\n        \"    load_in_4bit = True,      # False for LoRA 16bit. Choose False on H100s\\n\",\n        \"    offload_embedding = True, # Reduces VRAM by 1GB\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TfeUs-lQJDSq\"\n      },\n      \"source\": [\n        \"To do efficient RL, we will use LoRA, which adds only 1 to 5% extra weights to the model for fine-tuning. This reduces memory usage by 60% while retaining most of the accuracy. 
Read Unsloth's [gpt-oss RL Guide](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) for more details.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"8rGa-o3HJCo1\",\n        \"outputId\": \"6dc27dbf-0c60-4996-8e97-932aab7c14fb\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Unsloth: Making `model.base_model.model.model` require gradients\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"model = FastLanguageModel.get_peft_model(\\n\",\n        \"    model,\\n\",\n        \"    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\\n\",\n        \"    target_modules = [\\n\",\n        \"        \\\"q_proj\\\", \\\"k_proj\\\", \\\"v_proj\\\", \\\"o_proj\\\",\\n\",\n        \"        \\\"gate_proj\\\", \\\"up_proj\\\", \\\"down_proj\\\",\\n\",\n        \"    ],\\n\",\n        \"    lora_alpha = lora_rank*2, # *2 speeds up training\\n\",\n        \"    use_gradient_checkpointing = \\\"unsloth\\\", # Reduces memory usage\\n\",\n        \"    random_state = 3407,\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"N0QnO9_YJBOI\"\n      },\n      \"source\": [\n        \"# 2048 game\\n\",\n        \"\\n\",\n        \"We used GPT-5 to create a variant of the 2048 game. 
It should output the current game board state, and allow us to advance the game board state with 1 action (up, down, left, right).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"D9CI4jtgL5mw\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title (Collapsible) 2048 Game Implementation\\n\",\n        \"from dataclasses import dataclass, field\\n\",\n        \"from typing import List, Tuple, Optional\\n\",\n        \"import random\\n\",\n        \"import copy\\n\",\n        \"\\n\",\n        \"def _compress_and_merge_row_left(row: List[int]) -> Tuple[List[int], int, bool]:\\n\",\n        \"    n = len(row)\\n\",\n        \"    tiles = [x for x in row if x != 0]\\n\",\n        \"    gained = 0\\n\",\n        \"    i = 0\\n\",\n        \"    merged = []\\n\",\n        \"    while i < len(tiles):\\n\",\n        \"        if i + 1 < len(tiles) and tiles[i] == tiles[i + 1]:\\n\",\n        \"            v = tiles[i] * 2\\n\",\n        \"            gained += v\\n\",\n        \"            merged.append(v)\\n\",\n        \"            i += 2\\n\",\n        \"        else:\\n\",\n        \"            merged.append(tiles[i])\\n\",\n        \"            i += 1\\n\",\n        \"    merged += [0] * (n - len(merged))\\n\",\n        \"    changed = merged != row\\n\",\n        \"    return merged, gained, changed\\n\",\n        \"\\n\",\n        \"def _move_left(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:\\n\",\n        \"    changed_any = False\\n\",\n        \"    total_gain = 0\\n\",\n        \"    new_board = []\\n\",\n        \"    for row in board:\\n\",\n        \"        new_row, gained, changed = _compress_and_merge_row_left(row)\\n\",\n        \"        new_board.append(new_row)\\n\",\n        \"        total_gain += gained\\n\",\n        \"        changed_any = changed_any or changed\\n\",\n        \"    return 
new_board, total_gain, changed_any\\n\",\n        \"\\n\",\n        \"def _move_right(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:\\n\",\n        \"    changed_any = False\\n\",\n        \"    total_gain = 0\\n\",\n        \"    new_board = []\\n\",\n        \"    for row in board:\\n\",\n        \"        rev = list(reversed(row))\\n\",\n        \"        new_rev, gained, changed = _compress_and_merge_row_left(rev)\\n\",\n        \"        new_row = list(reversed(new_rev))\\n\",\n        \"        new_board.append(new_row)\\n\",\n        \"        total_gain += gained\\n\",\n        \"        changed_any = changed_any or changed\\n\",\n        \"    return new_board, total_gain, changed_any\\n\",\n        \"\\n\",\n        \"def _transpose(board: List[List[int]]) -> List[List[int]]:\\n\",\n        \"    return [list(row) for row in zip(*board)]\\n\",\n        \"\\n\",\n        \"def _move_up(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:\\n\",\n        \"    t = _transpose(board)\\n\",\n        \"    moved, gain, changed = _move_left(t)\\n\",\n        \"    return _transpose(moved), gain, changed\\n\",\n        \"\\n\",\n        \"def _move_down(board: List[List[int]]) -> Tuple[List[List[int]], int, bool]:\\n\",\n        \"    t = _transpose(board)\\n\",\n        \"    moved, gain, changed = _move_right(t)\\n\",\n        \"    return _transpose(moved), gain, changed\\n\",\n        \"\\n\",\n        \"def _empty_cells(board: List[List[int]]) -> List[Tuple[int, int]]:\\n\",\n        \"    size = len(board)\\n\",\n        \"    return [(r, c) for r in range(size) for c in range(size) if board[r][c] == 0]\\n\",\n        \"\\n\",\n        \"def _can_move(board: List[List[int]]) -> bool:\\n\",\n        \"    if _empty_cells(board):\\n\",\n        \"        return True\\n\",\n        \"    size = len(board)\\n\",\n        \"    for r in range(size):\\n\",\n        \"        for c in range(size - 1):\\n\",\n        \"            if 
board[r][c] == board[r][c + 1]:\\n\",\n        \"                return True\\n\",\n        \"    for r in range(size - 1):\\n\",\n        \"        for c in range(size):\\n\",\n        \"            if board[r][c] == board[r + 1][c]:\\n\",\n        \"                return True\\n\",\n        \"    return False\\n\",\n        \"\\n\",\n        \"@dataclass\\n\",\n        \"class GameBoard:\\n\",\n        \"    size: int\\n\",\n        \"    seed: Optional[int] = None\\n\",\n        \"    target: int = 2048\\n\",\n        \"    probability_fours: float = 0.10 # originally spawns (4) 10% of the time!\\n\",\n        \"    _rng: random.Random = field(init=False, repr=False)\\n\",\n        \"    _board: List[List[int]] = field(init=False, repr=False)\\n\",\n        \"    _score: int = field(default=0, init=False, repr=False)\\n\",\n        \"    _state: str = field(default=\\\"ongoing\\\", init=False, repr=False)\\n\",\n        \"\\n\",\n        \"    def __post_init__(self):\\n\",\n        \"        if self.size < 2:\\n\",\n        \"            raise ValueError(\\\"Board size must be at least 2.\\\")\\n\",\n        \"        self._rng = random.Random(self.seed)\\n\",\n        \"        self._board = [[0 for _ in range(self.size)] for _ in range(self.size)]\\n\",\n        \"        self._add_random_tile()\\n\",\n        \"        self._add_random_tile()\\n\",\n        \"        self._update_state_after_change()\\n\",\n        \"\\n\",\n        \"    class _BoardView:\\n\",\n        \"        def __init__(self, game: \\\"GameBoard\\\"):\\n\",\n        \"            self._game = game\\n\",\n        \"        def __iter__(self):\\n\",\n        \"            return iter(self._game._board)\\n\",\n        \"        def __len__(self):\\n\",\n        \"            return len(self._game._board)\\n\",\n        \"        def __getitem__(self, idx):\\n\",\n        \"            return self._game._board[idx]\\n\",\n        \"        def __repr__(self) -> str:\\n\",\n        \"     
       return repr(self._game._board)\\n\",\n        \"        __str__ = __repr__\\n\",\n        \"        def do_action(self, key: str) -> None:\\n\",\n        \"            self._game.do_action(key)\\n\",\n        \"        def state(self) -> str:\\n\",\n        \"            return self._game.state()\\n\",\n        \"        def pretty(self, colors: bool = True, border: bool = True, dot_for_zero: bool = True) -> str:\\n\",\n        \"            return self._game._render_pretty(colors=colors, border=border, dot_for_zero=dot_for_zero)\\n\",\n        \"\\n\",\n        \"    def board(self) -> \\\"_BoardView\\\":\\n\",\n        \"        return GameBoard._BoardView(self)\\n\",\n        \"    def state(self) -> str:\\n\",\n        \"        return self._state\\n\",\n        \"    def score(self) -> int:\\n\",\n        \"        return self._score\\n\",\n        \"    def do_action(self, key: str) -> None:\\n\",\n        \"        if self._state != \\\"ongoing\\\":\\n\",\n        \"            return\\n\",\n        \"        if not isinstance(key, str) or len(key) == 0:\\n\",\n        \"            self._state = \\\"failed\\\"\\n\",\n        \"            return\\n\",\n        \"        k = key.strip().lower()\\n\",\n        \"        if k == \\\"q\\\":\\n\",\n        \"            self._state = \\\"failed\\\"\\n\",\n        \"            return\\n\",\n        \"        move_map = {\\\"a\\\": _move_left, \\\"d\\\": _move_right, \\\"w\\\": _move_up, \\\"s\\\": _move_down}\\n\",\n        \"        if k not in move_map:\\n\",\n        \"            self._state = \\\"failed\\\"\\n\",\n        \"            return\\n\",\n        \"        mover = move_map[k]\\n\",\n        \"        new_board, gain, changed = mover(self._board)\\n\",\n        \"        if changed:\\n\",\n        \"            self._board = new_board\\n\",\n        \"            self._score += gain\\n\",\n        \"            self._add_random_tile()\\n\",\n        \"        
self._update_state_after_change()\\n\",\n        \"    def _add_random_tile(self) -> bool:\\n\",\n        \"        empties = _empty_cells(self._board)\\n\",\n        \"        if not empties:\\n\",\n        \"            return False\\n\",\n        \"        r, c = self._rng.choice(empties)\\n\",\n        \"        self._board[r][c] = 4 if self._rng.random() < self.probability_fours else 2\\n\",\n        \"        return True\\n\",\n        \"    def _update_state_after_change(self) -> None:\\n\",\n        \"        if any(self.target in row for row in self._board):\\n\",\n        \"            self._state = \\\"success\\\"\\n\",\n        \"            return\\n\",\n        \"        if not _can_move(self._board):\\n\",\n        \"            self._state = \\\"failed\\\"\\n\",\n        \"            return\\n\",\n        \"        self._state = \\\"ongoing\\\"\\n\",\n        \"    def _render_pretty(self, colors: bool = True, border: bool = True, dot_for_zero: bool = True) -> str:\\n\",\n        \"        \\\"\\\"\\\"\\n\",\n        \"        Pretty-print the board with colors that scale from 0 up to self.target.\\n\",\n        \"        Uses ANSI 256-color codes (works in most terminals). Set colors=False to disable.\\n\",\n        \"        \\\"\\\"\\\"\\n\",\n        \"        import math\\n\",\n        \"\\n\",\n        \"        b = self._board\\n\",\n        \"        mx = max((max(row) for row in b), default=0)\\n\",\n        \"        cell_w = max(3, len(str(mx)))\\n\",\n        \"\\n\",\n        \"        RESET = \\\"\\\\x1b[0m\\\"\\n\",\n        \"\\n\",\n        \"        # A smooth-ish gradient from cool → warm\\n\",\n        \"        # (blue/cyan/green → yellow/orange/red). 
Tweak or expand as you like.\\n\",\n        \"        GRAD = [33, 39, 45, 51, 50, 49, 48, 47, 46, 82, 118, 154, 190, 226, 220, 214, 208, 202, 196]\\n\",\n        \"        ZERO_FG = 239  # dim gray\\n\",\n        \"\\n\",\n        \"        def color_code(v: int) -> str:\\n\",\n        \"            if not colors:\\n\",\n        \"                return \\\"\\\"\\n\",\n        \"            if v == 0:\\n\",\n        \"                return f\\\"\\\\x1b[38;5;{ZERO_FG}m\\\"\\n\",\n        \"            # Normalize by exponent relative to target: r in [0,1]\\n\",\n        \"            t = max(2, self.target)  # safety; avoid log2(1)\\n\",\n        \"            # Guard: if v is not a power of two or is <1, handle gracefully\\n\",\n        \"            try:\\n\",\n        \"                r = max(0.0, min(1.0, math.log2(v) / math.log2(t)))\\n\",\n        \"            except ValueError:\\n\",\n        \"                r = 0.0\\n\",\n        \"            idx = int(round(r * (len(GRAD) - 1)))\\n\",\n        \"            return f\\\"\\\\x1b[38;5;{GRAD[idx]}m\\\"\\n\",\n        \"\\n\",\n        \"        def fmt(v: int) -> str:\\n\",\n        \"            s = \\\".\\\" if (v == 0 and dot_for_zero) else str(v)\\n\",\n        \"            s = s.rjust(cell_w)\\n\",\n        \"            return color_code(v) + s + (RESET if colors else \\\"\\\")\\n\",\n        \"\\n\",\n        \"        def hline(left: str, mid: str, right: str) -> str:\\n\",\n        \"            return left + mid.join(\\\"─\\\" * cell_w for _ in range(self.size)) + right\\n\",\n        \"\\n\",\n        \"        rows = []\\n\",\n        \"        if border:\\n\",\n        \"            rows.append(hline(\\\"┌\\\", \\\"┬\\\", \\\"┐\\\"))\\n\",\n        \"        for r in range(self.size):\\n\",\n        \"            content = \\\"│\\\".join(fmt(v) for v in b[r])\\n\",\n        \"            rows.append((\\\"│\\\" + content + \\\"│\\\") if border else content)\\n\",\n        \"            if 
border:\\n\",\n        \"                rows.append(hline(\\\"└\\\" if r == self.size - 1 else \\\"├\\\",\\n\",\n        \"                                \\\"┴\\\" if r == self.size - 1 else \\\"┼\\\",\\n\",\n        \"                                \\\"┘\\\" if r == self.size - 1 else \\\"┤\\\"))\\n\",\n        \"        return \\\"\\\\n\\\".join(rows)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"4BcaLniVKLpa\"\n      },\n      \"source\": [\n        \"For example, let's create a 5 × 5 board and set the target to 8 instead of 2048.\\n\",\n        \"\\n\",\n        \"**[NOTE]** 2048 originally spawns a (4) 10% of the time! We can disable this with `probability_fours = 0` for harder games. See the [Wikipedia page](https://en.wikipedia.org/wiki/2048_(video_game)) for more details.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"-M8kGaFRJ2ic\",\n        \"outputId\": \"fad6c36b-cb16-490f-ad4f-6bf998dd24ab\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            
\"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;48m  2\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ ongoing\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game = GameBoard(size = 5, seed = 42, target = 8, probability_fours = 0.10)\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"zclUeNxosv4k\",\n        \"outputId\": \"ad099448-d1f2-4471-cbc1-f463293e06ba\"\n      },\n      \"outputs\": [\n        {\n          \"data\": {\n            \"text/plain\": [\n              \"GameBoard(size=5, seed=42, target=8, probability_fours=0.1)\"\n            ]\n          },\n          \"execution_count\": 6,\n          \"metadata\": {},\n          \"output_type\": \"execute_result\"\n        }\n      ],\n      \"source\": [\n        \"game\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"envzrXmjKRff\"\n      },\n      \"source\": [\n        \"We'll use WASD for the action space:\\n\",\n        \"\\n\",\n        \"```\\n\",\n        \"   W\\n\",\n        \"A  S  D\\n\",\n        \"```\\n\",\n        \"Also `game.state()` will say `success` if we succeeded in getting the target!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"b-gSgthFI_wq\",\n        \"outputId\": 
\"68af4e66-80c8-4fa0-c7f3-e9ba22923494\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ ongoing\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game.do_action(\\\"A\\\")\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"lUDdHKAxvZf8\",\n        \"outputId\": \"38692fcc-bfa9-47b3-82f8-09bee2842d38\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            
\"│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;48m  2\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ ongoing\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game.do_action(\\\"W\\\")\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"wkTHxvvUvcmO\",\n        \"outputId\": \"f9447b03-b0eb-443e-e139-607f231c76fe\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  
.\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ ongoing\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game.do_action(\\\"D\\\")\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"XO8vlL-4vd-K\",\n        \"outputId\": \"a6f786bf-39d5-4a23-d79b-17ea9e94272c\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  
.\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ ongoing\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game.do_action(\\\"W\\\")\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"MEa2ngmrvfNm\",\n        \"outputId\": \"c27d9fca-55a0-42c4-dae5-bf8e402d7295\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;196m  8\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  
.\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┘ success\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game.do_action(\\\"D\\\")\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"gGL1X29Fy4n5\"\n      },\n      \"source\": [\n        \"If we take an action that is not part of the action space, no exception is raised, but the game state becomes `failed` and the game will not accept any more actions.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"VZeIHbqoy7yn\",\n        \"outputId\": \"11d15a8f-f09d-4833-8ef7-3bad0510e618\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"┌───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;190m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;48m  2\\u001b[0m│\\n\",\n            \"├───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┘ failed\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"game = GameBoard(size = 3, seed = 42, target = 8, probability_fours = 0.10)\\n\",\n        \"game.do_action(\\\"AA\\\") # Not in WASD\\n\",\n        \"game.do_action(\\\"W\\\")  # Doesn't do 
anything\\n\",\n        \"game.do_action(\\\"A\\\")  # Doesn't do anything\\n\",\n        \"print(game.board().pretty(), game.state())\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"VR6czU96cpxf\"\n      },\n      \"source\": [\n        \"# RL Environment Setup\\n\",\n        \"\\n\",\n        \"We'll set up a function that accepts a strategy, repeatedly asks it for an action in `WASD`, and checks the game state after each move.\\n\",\n        \"\\n\",\n        \"We'll also add a timer so the strategy runs for at most 2 seconds; otherwise it might never terminate!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"tdgjnf-8z_kr\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from typing import Callable\\n\",\n        \"from unsloth import execute_with_time_limit\\n\",\n        \"\\n\",\n        \"def _execute_strategy(strategy : Callable, game : GameBoard):\\n\",\n        \"    assert callable(strategy)\\n\",\n        \"\\n\",\n        \"    steps = 0\\n\",\n        \"    while game.state() == \\\"ongoing\\\":\\n\",\n        \"        # Pass a copy of the board so the strategy cannot mutate the real game state\\n\",\n        \"        action = strategy([list(row) for row in game.board()])\\n\",\n        \"        steps += 1\\n\",\n        \"        if type(action) is not str:\\n\",\n        \"            return steps, \\\"failed\\\"\\n\",\n        \"        game.do_action(action)\\n\",\n        \"    return steps, game.state()\\n\",\n        \"\\n\",\n        \"@execute_with_time_limit(2)\\n\",\n        \"def execute_strategy(strategy : Callable, game : GameBoard):\\n\",\n        \"    return _execute_strategy(strategy, game)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ywh0HizI9ayE\"\n      },\n      \"source\": [\n        \"Let's make a generic strategy to just hit `W`. 
We should expect this generic strategy to fail. Once `W` stops changing the board, no new tiles spawn and the game stays `ongoing` forever, so we hit the time limit:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"5bkhqoZc0IO8\",\n        \"outputId\": \"149e18be-dae2-4382-817a-620e7b40ebde\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Timed out with error = Timed out after 2s\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"def always_move_up(board):\\n\",\n        \"    return \\\"W\\\"\\n\",\n        \"\\n\",\n        \"game = GameBoard(size = 8, seed = 42, target = 2048, probability_fours = 0.10)\\n\",\n        \"try:\\n\",\n        \"    execute_strategy(always_move_up, game)\\n\",\n        \"except TimeoutError as e:\\n\",\n        \"    print(f\\\"Timed out with error = {str(e)}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"dkuHVdB09sgf\"\n      },\n      \"source\": [\n        \"To give strategies more time during gpt-oss-20b reinforcement learning, we raise the time limit to 5 seconds.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"SK-LfzsA9wbW\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"@execute_with_time_limit(5)\\n\",\n        \"def execute_strategy(strategy : Callable, game : GameBoard):\\n\",\n        \"    return _execute_strategy(strategy, game)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tRhLV_bZMYxy\"\n      },\n      \"source\": [\n        \"# Code Execution\\n\",\n        \"\\n\",\n        \"Before creating and executing a new Python function, we first have to check that the function does not access global variables or otherwise cheat. 
This counters *reward hacking*, since we don't want the function to cheat.\\n\",\n        \"\\n\",\n        \"For example, the code below is fine, since it only imports modules from the Python standard library. We use `check_python_modules`:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"zz80kvg6M4BG\",\n        \"outputId\": \"f13fdc0d-ddb3-4c4a-cf65-805dfb31dddd\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Only Python imports? True\\n\",\n            \"{'stdlib': ['math', 'typing'], 'non_stdlib': [], 'relative_imports': 0}\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"from unsloth import check_python_modules\\n\",\n        \"\\n\",\n        \"sample = \\\"\\\"\\\"\\n\",\n        \"def strategy(board):\\n\",\n        \"    import math\\n\",\n        \"    from typing import Callable\\n\",\n        \"    return \\\"W\\\"\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"ok, info = check_python_modules(sample)\\n\",\n        \"print(\\\"Only Python imports?\\\", ok)\\n\",\n        \"print(info)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"bZzVWgKQ-VIg\"\n      },\n      \"source\": [\n        \"The code below imports `numpy`, which is not part of the standard library, so we should not allow it to execute:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"Z89Jw1KB-Ux7\",\n        \"outputId\": \"1a4cc701-1677-44b9-d44e-3f3f6dfed8d2\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          
\"text\": [\n            \"Only Python imports? False\\n\",\n            \"{'stdlib': [], 'non_stdlib': ['numpy'], 'relative_imports': 0}\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"sample = \\\"\\\"\\\"\\n\",\n        \"def strategy(board):\\n\",\n        \"    from numpy import matmul\\n\",\n        \"    return \\\"W\\\"\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"ok, info = check_python_modules(sample)\\n\",\n        \"print(\\\"Only Python imports?\\\", ok)\\n\",\n        \"print(info)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"SDSrjOTLVyQm\"\n      },\n      \"source\": [\n        \"We also disallow global variable access. We'll use Unsloth's `create_locked_down_function` function\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"GcmYAmohVqw2\",\n        \"outputId\": \"bbfcbbb5-8063-42fe-b349-964554317ab8\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"name 'np' is not defined\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"from unsloth import create_locked_down_function\\n\",\n        \"function = \\\"\\\"\\\"\\n\",\n        \"def import_numpy():\\n\",\n        \"    np.matmul\\n\",\n        \"    print(\\\"Success\\\")\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"f = create_locked_down_function(function)\\n\",\n        \"try:\\n\",\n        \"    f()\\n\",\n        \"except Exception as e:\\n\",\n        \"    print(str(e))\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"5tJKwLUgZsRq\",\n        
\"outputId\": \"13588c11-6685-4627-b2d4-445bff9799c8\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"60\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"from unsloth import create_locked_down_function\\n\",\n        \"function = \\\"\\\"\\\"\\n\",\n        \"def add(a, b):\\n\",\n        \"    def adder(a):\\n\",\n        \"        return a + b\\n\",\n        \"    return adder(b) + b\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"f = create_locked_down_function(function)\\n\",\n        \"try:\\n\",\n        \"    print(f(10, 20))\\n\",\n        \"except Exception as e:\\n\",\n        \"    print(str(e))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"8CzwCyXIPK04\"\n      },\n      \"source\": [\n        \"# Data & RL task setup\\n\",\n        \"\\n\",\n        \"We now have to create a prompt to tell the model to create a strategy for the 2048 game. 
You can customize this prompt for other RL tasks.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"B-2RRE4HMrQO\",\n        \"outputId\": \"332255d7-1e6a-4cb4-9ede-c8a2f01378fe\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Create a new short 2048 strategy using only native Python code.\\n\",\n            \"You are given a list of list of numbers for the current board state.\\n\",\n            \"Output one action for \\\"W\\\", \\\"A\\\", \\\"S\\\", \\\"D\\\" on what is the optimal next step.\\n\",\n            \"Output your new short function in backticks using the format below:\\n\",\n            \"```python\\n\",\n            \"def strategy(board):\\n\",\n            \"    return \\\"W\\\" # Example\\n\",\n            \"```\\n\",\n            \"All helper functions should be inside def strategy. Only output the short function `strategy`.\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"prompt = \\\"\\\"\\\"\\n\",\n        \"Create a new short 2048 strategy using only native Python code.\\n\",\n        \"You are given a list of list of numbers for the current board state.\\n\",\n        \"Output one action for \\\"W\\\", \\\"A\\\", \\\"S\\\", \\\"D\\\" on what is the optimal next step.\\n\",\n        \"Output your new short function in backticks using the format below:\\n\",\n        \"```python\\n\",\n        \"def strategy(board):\\n\",\n        \"    return \\\"W\\\" # Example\\n\",\n        \"```\\n\",\n        \"All helper functions should be inside def strategy. 
Only output the short function `strategy`.\\n\",\n        \"\\\"\\\"\\\".strip()\\n\",\n        \"print(prompt)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"MIdudFUodN4i\"\n      },\n      \"source\": [\n        \"First, let's prompt gpt-oss without RL and see how it goes:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"9HJxrS76h3Ds\",\n        \"outputId\": \"fcfe2220-fc72-4af7-f8a1-f243311b0156\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\\n\",\n            \"Knowledge cutoff: 2024-06\\n\",\n            \"Current date: 2025-10-05\\n\",\n            \"\\n\",\n            \"Reasoning: low\\n\",\n            \"\\n\",\n            \"# Valid channels: analysis, commentary, final. Channel must be included for every message.\\n\",\n            \"Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>Create a new short 2048 strategy using only native Python code.\\n\",\n            \"You are given a list of list of numbers for the current board state.\\n\",\n            \"Output one action for \\\"W\\\", \\\"A\\\", \\\"S\\\", \\\"D\\\" on what is the optimal next step.\\n\",\n            \"Output your new short function in backticks using the format below:\\n\",\n            \"```python\\n\",\n            \"def strategy(board):\\n\",\n            \"    return \\\"W\\\" # Example\\n\",\n            \"```\\n\",\n            \"All helper functions should be inside def strategy. 
Only output the short function `strategy`.<|end|><|start|>assistant<|channel|>analysis<|message|>We need simple heuristic: return random move or based on empty spaces. Provide function.<|end|><|start|>assistant<|channel|>final<|message|>```python\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Count empty cells and choose move that reduces board density\\n\",\n            \"    empty_counts = {\\n\",\n            \"        \\\"W\\\": sum(row.count(0) for row in board[1:]),\\n\",\n            \"        \\\"S\\\": sum(row.count(0) for row in board[:-1]),\\n\",\n            \"        \\\"A\\\": sum(col.count(0) for col in zip(*board)[1:]),\\n\",\n            \"        \\\"D\\\": sum(col.count(0) for col in zip(*board)[:-1]),\\n\",\n            \"    }\\n\",\n            \"    # Prefer the direction with the most empty cells\\n\",\n            \"    return max(empty_counts, key=empty_counts.get)\\n\",\n            \"```<|return|>\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"text = tokenizer.apply_chat_template(\\n\",\n        \"    [{\\\"role\\\": \\\"user\\\", \\\"content\\\": prompt}],\\n\",\n        \"    tokenize = False,\\n\",\n        \"    add_generation_prompt = True,\\n\",\n        \"    reasoning_effort = \\\"low\\\",\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"from transformers import TextStreamer\\n\",\n        \"_ = model.generate(\\n\",\n        \"    **tokenizer(text, return_tensors = \\\"pt\\\").to(\\\"cuda\\\"),\\n\",\n        \"    temperature = 1.0,\\n\",\n        \"    max_new_tokens = 512,\\n\",\n        \"    streamer = TextStreamer(tokenizer, skip_prompt = False),\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"iknaWZNudTNq\"\n      },\n      \"source\": [\n        \"# Reward functions\\n\",\n        \"\\n\",\n        \"We now design an `extract_function` function, which extracts the function wrapped in triple backticks.
\\n\",\n        \"\\n\",\n        \"We also define 3 reward functions:\\n\",\n        \"\\n\",\n        \"1. `function_works`, which rewards the model if the strategy is a valid Python function.\\n\",\n        \"2. `no_cheating`, which checks whether the function imports other modules and penalizes it if it does.\\n\",\n        \"3. `strategy_succeeds`, which runs the generated strategy and checks whether it actually reaches the 2048 tile.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"8JJGXKdJ-Zl_\",\n        \"outputId\": \"80fd8078-1621-4c64-a906-5204b444addd\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"def strategy(board):\\n\",\n            \"    return \\\"W\\\" # Example\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"def extract_function(text):\\n\",\n        \"    if text.count(\\\"```\\\") >= 2:\\n\",\n        \"        first = text.find(\\\"```\\\") + 3\\n\",\n        \"        second = text.find(\\\"```\\\", first)\\n\",\n        \"        fx = text[first : second].strip()\\n\",\n        \"        fx = fx.removeprefix(\\\"python\\\\n\\\")\\n\",\n        \"        fx = fx[fx.find(\\\"def\\\"):]\\n\",\n        \"        if fx.startswith(\\\"def strategy(board):\\\"): return fx\\n\",\n        \"    return None\\n\",\n        \"print(extract_function(prompt))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"KLXEcf_HSJlI\"\n      },\n      \"source\": [\n        \"Below is our `function_works` reward function, which uses Python's `exec` but is guarded so that it cannot leak local or global variables. 
We can also use `check_python_modules` first to check if there are errors before even executing the function:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"h3-B0IIsS56S\",\n        \"outputId\": \"f3e174fa-2fbf-400b-ec7d-87590be3ef68\"\n      },\n      \"outputs\": [\n        {\n          \"data\": {\n            \"text/plain\": [\n              \"(False,\\n\",\n              \" {'error': \\\"SyntaxError: expected '(' (<unknown>, line 1)\\\",\\n\",\n              \"  'stdlib': [],\\n\",\n              \"  'non_stdlib': [],\\n\",\n              \"  'relative_imports': 0})\"\n            ]\n          },\n          \"execution_count\": 23,\n          \"metadata\": {},\n          \"output_type\": \"execute_result\"\n        }\n      ],\n      \"source\": [\n        \"ok, info = check_python_modules(\\\"def a\\\")\\n\",\n        \"ok, info\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"qgFNXORy-lpO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def function_works(completions, **kwargs):\\n\",\n        \"    scores = []\\n\",\n        \"    for completion in completions:\\n\",\n        \"        score = 0\\n\",\n        \"        response = completion[0][\\\"content\\\"]\\n\",\n        \"        function = extract_function(response)\\n\",\n        \"        if function is not None:\\n\",\n        \"            ok, info = check_python_modules(function)\\n\",\n        \"        if function is None or \\\"error\\\" in info:\\n\",\n        \"            score = -2.0\\n\",\n        \"        else:\\n\",\n        \"            try:\\n\",\n        \"                new_strategy = create_locked_down_function(function)\\n\",\n        \"                score = 1.0\\n\",\n        \"            
except:\\n\",\n        \"                score = -0.5\\n\",\n        \"        scores.append(score)\\n\",\n        \"    return scores\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Gf69i2WT-m4K\"\n      },\n      \"source\": [\n        \"`no_cheating` checks if the function cheated since it might have imported Numpy or other functions:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"cUfHzCVx-nGK\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def no_cheating(completions, **kwargs):\\n\",\n        \"    scores = []\\n\",\n        \"    for completion in completions:\\n\",\n        \"        score = 0\\n\",\n        \"        response = completion[0][\\\"content\\\"]\\n\",\n        \"        function = extract_function(response)\\n\",\n        \"        if function is not None:\\n\",\n        \"            ok, info = check_python_modules(function)\\n\",\n        \"            scores.append(1.0 if ok else -20.0) # Penalize heavily!\\n\",\n        \"        else:\\n\",\n        \"            scores.append(-1.0) # Failed creating function\\n\",\n        \"    return scores\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"slnqWG3FTror\"\n      },\n      \"source\": [\n        \"Next `strategy_succeeds` checks if the strategy actually allows the game to terminate. 
For example, a strategy that simply always returns \\\"W\\\" would fail by hitting the time limit of 10 seconds.\\n\",\n        \"\\n\",\n        \"We also add a global `PRINTER` counter to periodically print out the generated strategy and the resulting board state.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"sNi129lYTpZ2\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"PRINTER = 0\\n\",\n        \"def strategy_succeeds(completions, **kwargs):\\n\",\n        \"    global PRINTER\\n\",\n        \"    scores = []\\n\",\n        \"    # Pick one random seed so every completion in this batch plays the same board\\n\",\n        \"    seed = np.random.randint(10000)\\n\",\n        \"    for completion in completions:\\n\",\n        \"        printed = False\\n\",\n        \"        score = 0\\n\",\n        \"        response = completion[0][\\\"content\\\"]\\n\",\n        \"        function = extract_function(response)\\n\",\n        \"        if PRINTER % 5 == 0:\\n\",\n        \"            printed = True\\n\",\n        \"            print(function)\\n\",\n        \"        PRINTER += 1\\n\",\n        \"        if function is not None:\\n\",\n        \"            ok, info = check_python_modules(function)\\n\",\n        \"        if function is None or \\\"error\\\" in info:\\n\",\n        \"            scores.append(0)\\n\",\n        \"            continue\\n\",\n        \"        try:\\n\",\n        \"            new_strategy = create_locked_down_function(function)\\n\",\n        \"        except Exception:\\n\",\n        \"            scores.append(0)\\n\",\n        \"            continue\\n\",\n        \"        try:\\n\",\n        \"            game = GameBoard(size = 6, seed = seed, target = 2048, probability_fours = 0.10)\\n\",\n        \"            steps, game_state = execute_strategy(new_strategy, game)\\n\",\n        \"            print(f\\\"Steps = {steps} State = 
{game_state}\\\")\\n\",\n        \"            if printed is False:\\n\",\n        \"                print(function)\\n\",\n        \"            print(game.board().pretty())\\n\",\n        \"            if game_state == \\\"success\\\":\\n\",\n        \"                scores.append(20.0) # Success - massively reward!\\n\",\n        \"            else:\\n\",\n        \"                scores.append(2.0) # Failed but function works!\\n\",\n        \"        except TimeoutError as e:\\n\",\n        \"            print(\\\"Timeout\\\")\\n\",\n        \"            scores.append(-1.0) # Failed with timeout\\n\",\n        \"        except Exception as e:\\n\",\n        \"            print(f\\\"Exception = {str(e)}\\\")\\n\",\n        \"            scores.append(-3.0) # Failed\\n\",\n        \"    return scores\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TCpSxtvSeAG_\"\n      },\n      \"source\": [\n        \"We'll now create the dataset which includes a replica of our prompt. Remember to add a reasoning effort of low! 
You can choose high reasoning effort instead, but this will only work on GPUs with more memory, such as H100s.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"Ldf6SjLHVPRv\",\n        \"outputId\": \"589f7523-9835-49b5-c477-4e1d8b0744ff\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"181\\n\"\n          ]\n        },\n        {\n          \"data\": {\n            \"text/plain\": [\n              \"{'prompt': [{'content': 'Create a new short 2048 strategy using only native Python code.\\\\nYou are given a list of list of numbers for the current board state.\\\\nOutput one action for \\\"W\\\", \\\"A\\\", \\\"S\\\", \\\"D\\\" on what is the optimal next step.\\\\nOutput your new short function in backticks using the format below:\\\\n```python\\\\ndef strategy(board):\\\\n    return \\\"W\\\" # Example\\\\n```\\\\nAll helper functions should be inside def strategy. 
Only output the short function `strategy`.',\\n\",\n              \"   'role': 'user'}],\\n\",\n              \" 'answer': 0,\\n\",\n              \" 'reasoning_effort': 'low'}\"\n            ]\n          },\n          \"execution_count\": 27,\n          \"metadata\": {},\n          \"output_type\": \"execute_result\"\n        }\n      ],\n      \"source\": [\n        \"from datasets import Dataset\\n\",\n        \"dataset = Dataset.from_list([{\\\"prompt\\\" : [{\\\"role\\\": \\\"user\\\", \\\"content\\\": prompt.strip()}], \\\"answer\\\" : 0, \\\"reasoning_effort\\\": \\\"low\\\"}]*1000)\\n\",\n        \"maximum_length = len(tokenizer.apply_chat_template([{\\\"role\\\": \\\"user\\\", \\\"content\\\": prompt.strip()}], add_generation_prompt = True))\\n\",\n        \"print(maximum_length)\\n\",\n        \"dataset[0]\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"9-IOMhVg-2AM\"\n      },\n      \"source\": [\n        \"<a name=\\\"Train\\\"></a>\\n\",\n        \"### Train the model\\n\",\n        \"\\n\",\n        \"Now set up GRPO Trainer and all configurations! We also support GSPO, GAPO, Dr GRPO and more! 
Go to the Unsloth [Reinforcement Learning Docs](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide) for more options.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"ptqkXK2D4d6p\",\n        \"outputId\": \"2061b833-5b98-4a2b-e7f5-4bc4652d8300\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n\",\n            \"We will change the batch size of 1 to the `num_generations` of 2\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"max_prompt_length = maximum_length + 1 # + 1 just in case!\\n\",\n        \"max_completion_length = max_seq_length - max_prompt_length\\n\",\n        \"\\n\",\n        \"from trl import GRPOConfig, GRPOTrainer\\n\",\n        \"training_args = GRPOConfig(\\n\",\n        \"    temperature = 1.0,\\n\",\n        \"    learning_rate = 5e-5,\\n\",\n        \"    weight_decay = 0.01,\\n\",\n        \"    warmup_ratio = 0.1,\\n\",\n        \"    lr_scheduler_type = \\\"linear\\\",\\n\",\n        \"    optim = \\\"adamw_8bit\\\",\\n\",\n        \"    logging_steps = 1,\\n\",\n        \"    per_device_train_batch_size = 1,\\n\",\n        \"    gradient_accumulation_steps = 1, # Increase to 4 for smoother training\\n\",\n        \"    num_generations = 2, # Decrease if out of memory\\n\",\n        \"    max_prompt_length = max_prompt_length,\\n\",\n        \"    max_completion_length = max_completion_length,\\n\",\n        \"    # num_train_epochs = 1, # Set to 1 for a full training run\\n\",\n        \"    max_steps = 1000,\\n\",\n        \"    save_steps = 100,\\n\",\n        \"    report_to = \\\"none\\\", # Can use Weights & Biases, TrackIO\\n\",\n     
   \"    output_dir = \\\"outputs\\\",\\n\",\n        \"\\n\",\n        \"    # For optional training + evaluation\\n\",\n        \"    # fp16_full_eval = True,\\n\",\n        \"    # per_device_eval_batch_size = 4,\\n\",\n        \"    # eval_accumulation_steps = 1,\\n\",\n        \"    # eval_strategy = \\\"steps\\\",\\n\",\n        \"    # eval_steps = 1,\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"r9Mv8UZO5hz-\"\n      },\n      \"source\": [\n        \"And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!\\n\",\n        \"\\n\",\n        \"You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!\\n\",\n        \"\\n\",\n        \"| Step | Training Loss | reward    | reward_std | completion_length | kl       |\\n\",\n        \"|------|---------------|-----------|------------|-------------------|----------|\\n\",\n        \"| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |\\n\",\n        \"| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |\\n\",\n        \"| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"vzOuSVCL_GA9\",\n        \"outputId\": \"349f907c-cc67-4890-e131-397694679634\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Unsloth: Switching to float32 training since model cannot work with float16\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# For optional training + evaluation\\n\",\n        \"# 
new_dataset = dataset.train_test_split(test_size = 0.01)\\n\",\n        \"\\n\",\n        \"trainer = GRPOTrainer(\\n\",\n        \"    model = model,\\n\",\n        \"    processing_class = tokenizer,\\n\",\n        \"    reward_funcs = [\\n\",\n        \"        function_works,\\n\",\n        \"        no_cheating,\\n\",\n        \"        strategy_succeeds,\\n\",\n        \"    ],\\n\",\n        \"    args = training_args,\\n\",\n        \"    train_dataset = dataset,\\n\",\n        \"\\n\",\n        \"    # For optional training + evaluation\\n\",\n        \"    # train_dataset = new_dataset[\\\"train\\\"],\\n\",\n        \"    # eval_dataset = new_dataset[\\\"test\\\"],\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fQhtuwP4cf34\"\n      },\n      \"source\": [\n        \"And let's train the model!\\n\",\n        \"\\n\",\n        \"**NOTE** On a free T4 GPU, a single generation might take 5 minutes, since it's an older GPU. An A100 or H100 will be much faster!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 30,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\",\n          \"height\": 1000\n        },\n        \"id\": \"VGRxPdSCcfC3\",\n        \"outputId\": \"f8bb720c-6d69-4f43-d9d1-a404842d2dff\"\n      },\n      \"outputs\": [\n        {\n          \"metadata\": {\n            \"tags\": null\n          },\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': 199998, 'pad_token_id': 200017}.\\n\",\n            \"==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 2\\n\",\n            \"   \\\\\\\\   /|    Num examples = 1,000 | Num Epochs = 1 | Total steps = 1,000\\n\",\n            \"O^O/ \\\\_/ \\\\    Batch size per device = 2 | Gradient accumulation steps = 1\\n\",\n            \"\\\\        /    Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2\\n\",\n            \" \\\"-____-\\\"     Trainable parameters = 1,990,656 of 20,916,747,840 (0.01% trained)\\n\",\n            \"`generation_config` default values have been modified to match model-specific defaults: {'max_length': 131072}. If this is not desired, please set these values explicitly.\\n\"\n          ]\n        },\n        {\n          \"metadata\": {\n            \"tags\": null\n          },\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"None\\n\",\n            \"Steps = 1 State = failed\\n\",\n            \"def strategy(board):\\n\",\n            \"    # simple heuristic: prefer right or down, then left, then up\\n\",\n            \"    for move in \\\"R D L U\\\".split():\\n\",\n            \"        pass\\n\",\n            \"┌───┬───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  
.\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┴───┘\\n\"\n          ]\n        },\n        {\n          \"data\": {\n            \"text/html\": [\n              \"\\n\",\n              \"    <div>\\n\",\n              \"      \\n\",\n              \"      <progress value='86' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\\n\",\n              \"      [  86/1000 8:06:01 < 88:08:29, 0.00 it/s, Epoch 0.09/1]\\n\",\n              \"    </div>\\n\",\n              \"    <table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n              \"  <thead>\\n\",\n              \" <tr style=\\\"text-align: left;\\\">\\n\",\n              \"      <th>Step</th>\\n\",\n              \"      <th>Training Loss</th>\\n\",\n              \"      <th>reward</th>\\n\",\n              \"      <th>reward_std</th>\\n\",\n              \"      <th>completions / mean_length</th>\\n\",\n              \"      <th>completions / min_length</th>\\n\",\n              \"      <th>completions / max_length</th>\\n\",\n              \"      <th>completions / clipped_ratio</th>\\n\",\n              \"      <th>completions / mean_terminated_length</th>\\n\",\n              \"   
   <th>completions / min_terminated_length</th>\\n\",\n              \"      <th>completions / max_terminated_length</th>\\n\",\n              \"      <th>kl</th>\\n\",\n              \"      <th>rewards / function_works / mean</th>\\n\",\n              \"      <th>rewards / function_works / std</th>\\n\",\n              \"      <th>rewards / no_cheating / mean</th>\\n\",\n              \"      <th>rewards / no_cheating / std</th>\\n\",\n              \"      <th>rewards / strategy_succeeds / mean</th>\\n\",\n              \"      <th>rewards / strategy_succeeds / std</th>\\n\",\n              \"    </tr>\\n\",\n              \"  </thead>\\n\",\n              \"  <tbody>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>1</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>4.949748</td>\\n\",\n              \"      <td>329.000000</td>\\n\",\n              \"      <td>72.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>72.000000</td>\\n\",\n              \"      <td>72.000000</td>\\n\",\n              \"      <td>72.000000</td>\\n\",\n              \"      <td>0.002197</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>2</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>4.949748</td>\\n\",\n              \"      <td>550.500000</td>\\n\",\n              \"      <td>515.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      
<td>0.500000</td>\\n\",\n              \"      <td>515.000000</td>\\n\",\n              \"      <td>515.000000</td>\\n\",\n              \"      <td>515.000000</td>\\n\",\n              \"      <td>0.000298</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>3</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>538.000000</td>\\n\",\n              \"      <td>490.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>490.000000</td>\\n\",\n              \"      <td>490.000000</td>\\n\",\n              \"      <td>490.000000</td>\\n\",\n              \"      <td>0.000276</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>4</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>2.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>325.000000</td>\\n\",\n              \"      <td>120.000000</td>\\n\",\n              \"      <td>530.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>325.000000</td>\\n\",\n              \"      <td>120.000000</td>\\n\",\n              \"      
<td>530.000000</td>\\n\",\n              \"      <td>0.000568</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>5</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>437.000000</td>\\n\",\n              \"      <td>288.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>288.000000</td>\\n\",\n              \"      <td>288.000000</td>\\n\",\n              \"      <td>288.000000</td>\\n\",\n              \"      <td>0.001381</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>6</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>308.500000</td>\\n\",\n              \"      <td>301.000000</td>\\n\",\n              \"      <td>316.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>308.500000</td>\\n\",\n              \"      <td>301.000000</td>\\n\",\n              \"      <td>316.000000</td>\\n\",\n              \"      <td>0.000826</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      
<td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-3.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>7</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>519.000000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>0.000223</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>8</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>333.500000</td>\\n\",\n              \"      <td>81.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>81.000000</td>\\n\",\n              \"      <td>81.000000</td>\\n\",\n              \"      <td>81.000000</td>\\n\",\n              \"      <td>0.001181</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      
<td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>9</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>568.500000</td>\\n\",\n              \"      <td>551.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>551.000000</td>\\n\",\n              \"      <td>551.000000</td>\\n\",\n              \"      <td>551.000000</td>\\n\",\n              \"      <td>0.000281</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>10</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-3.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000153</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      
<td>11</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>2.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>330.000000</td>\\n\",\n              \"      <td>264.000000</td>\\n\",\n              \"      <td>396.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>330.000000</td>\\n\",\n              \"      <td>264.000000</td>\\n\",\n              \"      <td>396.000000</td>\\n\",\n              \"      <td>0.004015</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>12</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>374.500000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>374.500000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>0.000245</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>13</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n    
          \"      <td>520.500000</td>\\n\",\n              \"      <td>455.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>455.000000</td>\\n\",\n              \"      <td>455.000000</td>\\n\",\n              \"      <td>455.000000</td>\\n\",\n              \"      <td>0.000915</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>14</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>406.500000</td>\\n\",\n              \"      <td>227.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>227.000000</td>\\n\",\n              \"      <td>227.000000</td>\\n\",\n              \"      <td>227.000000</td>\\n\",\n              \"      <td>0.007664</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>15</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>348.500000</td>\\n\",\n              \"      <td>302.000000</td>\\n\",\n              \"      <td>395.000000</td>\\n\",\n              \"  
    <td>0.000000</td>\\n\",\n              \"      <td>348.500000</td>\\n\",\n              \"      <td>302.000000</td>\\n\",\n              \"      <td>395.000000</td>\\n\",\n              \"      <td>0.002411</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>16</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>408.000000</td>\\n\",\n              \"      <td>379.000000</td>\\n\",\n              \"      <td>437.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>408.000000</td>\\n\",\n              \"      <td>379.000000</td>\\n\",\n              \"      <td>437.000000</td>\\n\",\n              \"      <td>0.002496</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>17</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-12.500000</td>\\n\",\n              \"      <td>13.435029</td>\\n\",\n              \"      <td>493.000000</td>\\n\",\n              \"      <td>400.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>400.000000</td>\\n\",\n              \"      <td>400.000000</td>\\n\",\n              \"      
<td>400.000000</td>\\n\",\n              \"      <td>0.009901</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-10.500000</td>\\n\",\n              \"      <td>13.435029</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>18</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>413.000000</td>\\n\",\n              \"      <td>260.000000</td>\\n\",\n              \"      <td>566.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>413.000000</td>\\n\",\n              \"      <td>260.000000</td>\\n\",\n              \"      <td>566.000000</td>\\n\",\n              \"      <td>0.021275</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>19</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>487.500000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>389.000000</td>\\n\",\n              \"      <td>0.019204</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      
<td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>20</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.001022</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>21</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>397.500000</td>\\n\",\n              \"      <td>276.000000</td>\\n\",\n              \"      <td>519.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>397.500000</td>\\n\",\n              \"      <td>276.000000</td>\\n\",\n              \"      <td>519.000000</td>\\n\",\n              \"      <td>0.027686</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      
<td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>22</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>499.500000</td>\\n\",\n              \"      <td>486.000000</td>\\n\",\n              \"      <td>513.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>499.500000</td>\\n\",\n              \"      <td>486.000000</td>\\n\",\n              \"      <td>513.000000</td>\\n\",\n              \"      <td>0.007218</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>23</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.250000</td>\\n\",\n              \"      <td>2.474874</td>\\n\",\n              \"      <td>575.500000</td>\\n\",\n              \"      <td>565.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>565.000000</td>\\n\",\n              \"      <td>565.000000</td>\\n\",\n              \"      <td>565.000000</td>\\n\",\n              \"      <td>0.005928</td>\\n\",\n              \"      <td>-1.250000</td>\\n\",\n              \"      <td>1.060660</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      
<td>24</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>563.500000</td>\\n\",\n              \"      <td>541.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>541.000000</td>\\n\",\n              \"      <td>541.000000</td>\\n\",\n              \"      <td>541.000000</td>\\n\",\n              \"      <td>0.008769</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>25</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>444.500000</td>\\n\",\n              \"      <td>303.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>303.000000</td>\\n\",\n              \"      <td>303.000000</td>\\n\",\n              \"      <td>303.000000</td>\\n\",\n              \"      <td>0.084963</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>26</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      
<td>1.414214</td>\\n\",\n              \"      <td>419.000000</td>\\n\",\n              \"      <td>252.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>252.000000</td>\\n\",\n              \"      <td>252.000000</td>\\n\",\n              \"      <td>252.000000</td>\\n\",\n              \"      <td>0.114125</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>27</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>339.500000</td>\\n\",\n              \"      <td>321.000000</td>\\n\",\n              \"      <td>358.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>339.500000</td>\\n\",\n              \"      <td>321.000000</td>\\n\",\n              \"      <td>358.000000</td>\\n\",\n              \"      <td>0.033457</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>28</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>372.500000</td>\\n\",\n              \"      <td>311.000000</td>\\n\",\n              \"      
<td>434.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>372.500000</td>\\n\",\n              \"      <td>311.000000</td>\\n\",\n              \"      <td>434.000000</td>\\n\",\n              \"      <td>0.081829</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>29</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>387.500000</td>\\n\",\n              \"      <td>336.000000</td>\\n\",\n              \"      <td>439.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>387.500000</td>\\n\",\n              \"      <td>336.000000</td>\\n\",\n              \"      <td>439.000000</td>\\n\",\n              \"      <td>0.100017</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>30</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>463.000000</td>\\n\",\n              \"      <td>410.000000</td>\\n\",\n              \"      <td>516.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>463.000000</td>\\n\",\n              \"      
<td>410.000000</td>\\n\",\n              \"      <td>516.000000</td>\\n\",\n              \"      <td>0.095180</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>31</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>445.500000</td>\\n\",\n              \"      <td>305.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>305.000000</td>\\n\",\n              \"      <td>305.000000</td>\\n\",\n              \"      <td>305.000000</td>\\n\",\n              \"      <td>0.321803</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>32</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>425.000000</td>\\n\",\n              \"      <td>310.000000</td>\\n\",\n              \"      <td>540.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>425.000000</td>\\n\",\n              \"      <td>310.000000</td>\\n\",\n              \"      <td>540.000000</td>\\n\",\n              \"      <td>0.335011</td>\\n\",\n              \"      
<td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>33</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>458.500000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>0.362238</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>34</td>\\n\",\n              \"      <td>0.000500</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>347.500000</td>\\n\",\n              \"      <td>207.000000</td>\\n\",\n              \"      <td>488.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>347.500000</td>\\n\",\n              \"      <td>207.000000</td>\\n\",\n              \"      <td>488.000000</td>\\n\",\n              \"      <td>0.518291</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      
<td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>35</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>471.000000</td>\\n\",\n              \"      <td>356.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>356.000000</td>\\n\",\n              \"      <td>356.000000</td>\\n\",\n              \"      <td>356.000000</td>\\n\",\n              \"      <td>0.383606</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>36</td>\\n\",\n              \"      <td>0.000700</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>393.000000</td>\\n\",\n              \"      <td>200.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>200.000000</td>\\n\",\n              \"      <td>200.000000</td>\\n\",\n              \"      <td>200.000000</td>\\n\",\n              \"      <td>0.674902</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n       
       \"    <tr>\\n\",\n              \"      <td>37</td>\\n\",\n              \"      <td>0.000700</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>344.500000</td>\\n\",\n              \"      <td>198.000000</td>\\n\",\n              \"      <td>491.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>344.500000</td>\\n\",\n              \"      <td>198.000000</td>\\n\",\n              \"      <td>491.000000</td>\\n\",\n              \"      <td>0.689294</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>38</td>\\n\",\n              \"      <td>0.000600</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>473.500000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>0.607979</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>39</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n  
            \"      <td>1.414214</td>\\n\",\n              \"      <td>380.000000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>399.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>380.000000</td>\\n\",\n              \"      <td>361.000000</td>\\n\",\n              \"      <td>399.000000</td>\\n\",\n              \"      <td>0.142165</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>40</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>386.500000</td>\\n\",\n              \"      <td>352.000000</td>\\n\",\n              \"      <td>421.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>386.500000</td>\\n\",\n              \"      <td>352.000000</td>\\n\",\n              \"      <td>421.000000</td>\\n\",\n              \"      <td>0.293521</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>41</td>\\n\",\n              \"      <td>0.000500</td>\\n\",\n              \"      <td>-10.500000</td>\\n\",\n              \"      <td>16.263456</td>\\n\",\n              \"      <td>107.500000</td>\\n\",\n              \"      <td>89.000000</td>\\n\",\n              \"   
   <td>126.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>107.500000</td>\\n\",\n              \"      <td>89.000000</td>\\n\",\n              \"      <td>126.000000</td>\\n\",\n              \"      <td>0.465591</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>-9.500000</td>\\n\",\n              \"      <td>14.849242</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>42</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      <td>-0.250000</td>\\n\",\n              \"      <td>1.060660</td>\\n\",\n              \"      <td>410.000000</td>\\n\",\n              \"      <td>373.000000</td>\\n\",\n              \"      <td>447.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>410.000000</td>\\n\",\n              \"      <td>373.000000</td>\\n\",\n              \"      <td>447.000000</td>\\n\",\n              \"      <td>0.314028</td>\\n\",\n              \"      <td>0.250000</td>\\n\",\n              \"      <td>1.060660</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>43</td>\\n\",\n              \"      <td>0.000800</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>473.000000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      
<td>360.000000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>0.753577</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>44</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>528.500000</td>\\n\",\n              \"      <td>471.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>471.000000</td>\\n\",\n              \"      <td>471.000000</td>\\n\",\n              \"      <td>471.000000</td>\\n\",\n              \"      <td>0.370155</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>45</td>\\n\",\n              \"      <td>0.000600</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>293.000000</td>\\n\",\n              \"      <td>427.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>360.000000</td>\\n\",\n              \"      <td>293.000000</td>\\n\",\n              \"      <td>427.000000</td>\\n\",\n              \"      <td>0.609444</td>\\n\",\n              \"      
<td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>46</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>581.500000</td>\\n\",\n              \"      <td>577.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>577.000000</td>\\n\",\n              \"      <td>577.000000</td>\\n\",\n              \"      <td>577.000000</td>\\n\",\n              \"      <td>0.021817</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>47</td>\\n\",\n              \"      <td>0.000900</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>466.500000</td>\\n\",\n              \"      <td>347.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>347.000000</td>\\n\",\n              \"      <td>347.000000</td>\\n\",\n              \"      <td>347.000000</td>\\n\",\n              \"      <td>0.863071</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      
<td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>48</td>\\n\",\n              \"      <td>0.000700</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>495.000000</td>\\n\",\n              \"      <td>404.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>404.000000</td>\\n\",\n              \"      <td>404.000000</td>\\n\",\n              \"      <td>404.000000</td>\\n\",\n              \"      <td>0.727124</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>49</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>558.500000</td>\\n\",\n              \"      <td>531.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>531.000000</td>\\n\",\n              \"      <td>531.000000</td>\\n\",\n              \"      <td>531.000000</td>\\n\",\n              \"      <td>0.173142</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n       
       \"    <tr>\\n\",\n              \"      <td>50</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>477.000000</td>\\n\",\n              \"      <td>465.000000</td>\\n\",\n              \"      <td>489.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>477.000000</td>\\n\",\n              \"      <td>465.000000</td>\\n\",\n              \"      <td>489.000000</td>\\n\",\n              \"      <td>0.089374</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>51</td>\\n\",\n              \"      <td>0.001400</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>367.500000</td>\\n\",\n              \"      <td>149.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>149.000000</td>\\n\",\n              \"      <td>149.000000</td>\\n\",\n              \"      <td>149.000000</td>\\n\",\n              \"      <td>1.374907</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>52</td>\\n\",\n              \"      <td>0.000900</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n 
             \"      <td>1.414214</td>\\n\",\n              \"      <td>458.500000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>331.000000</td>\\n\",\n              \"      <td>0.929248</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>53</td>\\n\",\n              \"      <td>0.000900</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>475.000000</td>\\n\",\n              \"      <td>364.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>364.000000</td>\\n\",\n              \"      <td>364.000000</td>\\n\",\n              \"      <td>364.000000</td>\\n\",\n              \"      <td>0.887930</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>54</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>439.000000</td>\\n\",\n              \"      <td>424.000000</td>\\n\",\n              \" 
     <td>454.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>439.000000</td>\\n\",\n              \"      <td>424.000000</td>\\n\",\n              \"      <td>454.000000</td>\\n\",\n              \"      <td>0.126352</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>55</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>323.500000</td>\\n\",\n              \"      <td>293.000000</td>\\n\",\n              \"      <td>354.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>323.500000</td>\\n\",\n              \"      <td>293.000000</td>\\n\",\n              \"      <td>354.000000</td>\\n\",\n              \"      <td>0.367167</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>56</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>543.000000</td>\\n\",\n              \"      <td>500.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>543.000000</td>\\n\",\n              \"      
<td>500.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.375893</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>57</td>\\n\",\n              \"      <td>0.000700</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>382.000000</td>\\n\",\n              \"      <td>317.000000</td>\\n\",\n              \"      <td>447.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>382.000000</td>\\n\",\n              \"      <td>317.000000</td>\\n\",\n              \"      <td>447.000000</td>\\n\",\n              \"      <td>0.687571</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>58</td>\\n\",\n              \"      <td>0.000600</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>506.000000</td>\\n\",\n              \"      <td>426.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>426.000000</td>\\n\",\n              \"      <td>426.000000</td>\\n\",\n              \"      <td>426.000000</td>\\n\",\n              \"      <td>0.648271</td>\\n\",\n              \"      
<td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>59</td>\\n\",\n              \"      <td>0.001100</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>260.500000</td>\\n\",\n              \"      <td>187.000000</td>\\n\",\n              \"      <td>334.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>260.500000</td>\\n\",\n              \"      <td>187.000000</td>\\n\",\n              \"      <td>334.000000</td>\\n\",\n              \"      <td>1.084255</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>60</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>523.500000</td>\\n\",\n              \"      <td>495.000000</td>\\n\",\n              \"      <td>552.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>523.500000</td>\\n\",\n              \"      <td>495.000000</td>\\n\",\n              \"      <td>552.000000</td>\\n\",\n              \"      <td>0.198019</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      
<td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>61</td>\\n\",\n              \"      <td>0.001000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>471.500000</td>\\n\",\n              \"      <td>357.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>357.000000</td>\\n\",\n              \"      <td>357.000000</td>\\n\",\n              \"      <td>357.000000</td>\\n\",\n              \"      <td>0.987108</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>62</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>532.000000</td>\\n\",\n              \"      <td>478.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>478.000000</td>\\n\",\n              \"      <td>478.000000</td>\\n\",\n              \"      <td>478.000000</td>\\n\",\n              \"      <td>0.428900</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n       
       \"    <tr>\\n\",\n              \"      <td>63</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>411.000000</td>\\n\",\n              \"      <td>400.000000</td>\\n\",\n              \"      <td>422.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>411.000000</td>\\n\",\n              \"      <td>400.000000</td>\\n\",\n              \"      <td>422.000000</td>\\n\",\n              \"      <td>0.107686</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-3.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>64</td>\\n\",\n              \"      <td>0.001000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>470.500000</td>\\n\",\n              \"      <td>355.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>355.000000</td>\\n\",\n              \"      <td>355.000000</td>\\n\",\n              \"      <td>355.000000</td>\\n\",\n              \"      <td>0.967091</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>65</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      
<td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>553.000000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>0.262037</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>66</td>\\n\",\n              \"      <td>0.000400</td>\\n\",\n              \"      <td>2.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>471.500000</td>\\n\",\n              \"      <td>423.000000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>471.500000</td>\\n\",\n              \"      <td>423.000000</td>\\n\",\n              \"      <td>520.000000</td>\\n\",\n              \"      <td>0.414690</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>67</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>580.500000</td>\\n\",\n              \"      
<td>575.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>575.000000</td>\\n\",\n              \"      <td>575.000000</td>\\n\",\n              \"      <td>575.000000</td>\\n\",\n              \"      <td>0.035250</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>68</td>\\n\",\n              \"      <td>0.001200</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>435.000000</td>\\n\",\n              \"      <td>284.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>284.000000</td>\\n\",\n              \"      <td>284.000000</td>\\n\",\n              \"      <td>284.000000</td>\\n\",\n              \"      <td>1.168353</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>69</td>\\n\",\n              \"      <td>0.000800</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>492.000000</td>\\n\",\n              \"      <td>398.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      
<td>398.000000</td>\\n\",\n              \"      <td>398.000000</td>\\n\",\n              \"      <td>398.000000</td>\\n\",\n              \"      <td>0.789415</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>70</td>\\n\",\n              \"      <td>0.000700</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>291.500000</td>\\n\",\n              \"      <td>240.000000</td>\\n\",\n              \"      <td>343.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>291.500000</td>\\n\",\n              \"      <td>240.000000</td>\\n\",\n              \"      <td>343.000000</td>\\n\",\n              \"      <td>0.723002</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>71</td>\\n\",\n              \"      <td>0.001000</td>\\n\",\n              \"      <td>-10.500000</td>\\n\",\n              \"      <td>16.263456</td>\\n\",\n              \"      <td>407.000000</td>\\n\",\n              \"      <td>301.000000</td>\\n\",\n              \"      <td>513.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>407.000000</td>\\n\",\n              \"      <td>301.000000</td>\\n\",\n              \"      <td>513.000000</td>\\n\",\n              \"      
<td>0.958203</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>-9.500000</td>\\n\",\n              \"      <td>14.849242</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>72</td>\\n\",\n              \"      <td>0.000900</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>362.500000</td>\\n\",\n              \"      <td>279.000000</td>\\n\",\n              \"      <td>446.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>362.500000</td>\\n\",\n              \"      <td>279.000000</td>\\n\",\n              \"      <td>446.000000</td>\\n\",\n              \"      <td>0.902191</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>73</td>\\n\",\n              \"      <td>0.000100</td>\\n\",\n              \"      <td>0.750000</td>\\n\",\n              \"      <td>0.353553</td>\\n\",\n              \"      <td>479.000000</td>\\n\",\n              \"      <td>466.000000</td>\\n\",\n              \"      <td>492.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>479.000000</td>\\n\",\n              \"      <td>466.000000</td>\\n\",\n              \"      <td>492.000000</td>\\n\",\n              \"      <td>0.102604</td>\\n\",\n              \"      <td>0.250000</td>\\n\",\n              \"      <td>1.060660</td>\\n\",\n              \"      
<td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>74</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>579.000000</td>\\n\",\n              \"      <td>572.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>572.000000</td>\\n\",\n              \"      <td>572.000000</td>\\n\",\n              \"      <td>572.000000</td>\\n\",\n              \"      <td>0.049443</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-1.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>75</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>530.500000</td>\\n\",\n              \"      <td>507.000000</td>\\n\",\n              \"      <td>554.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>530.500000</td>\\n\",\n              \"      <td>507.000000</td>\\n\",\n              \"      <td>554.000000</td>\\n\",\n              \"      <td>0.173276</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      
<td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>76</td>\\n\",\n              \"      <td>0.000500</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>401.000000</td>\\n\",\n              \"      <td>353.000000</td>\\n\",\n              \"      <td>449.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>401.000000</td>\\n\",\n              \"      <td>353.000000</td>\\n\",\n              \"      <td>449.000000</td>\\n\",\n              \"      <td>0.522857</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>77</td>\\n\",\n              \"      <td>0.000300</td>\\n\",\n              \"      <td>0.750000</td>\\n\",\n              \"      <td>0.353553</td>\\n\",\n              \"      <td>512.500000</td>\\n\",\n              \"      <td>473.000000</td>\\n\",\n              \"      <td>552.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>512.500000</td>\\n\",\n              \"      <td>473.000000</td>\\n\",\n              \"      <td>552.000000</td>\\n\",\n              \"      <td>0.271977</td>\\n\",\n              \"      <td>0.250000</td>\\n\",\n              \"      <td>1.060660</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>78</td>\\n\",\n              \"      
<td>0.000200</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>432.500000</td>\\n\",\n              \"      <td>411.000000</td>\\n\",\n              \"      <td>454.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>432.500000</td>\\n\",\n              \"      <td>411.000000</td>\\n\",\n              \"      <td>454.000000</td>\\n\",\n              \"      <td>0.181327</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>79</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>10.500000</td>\\n\",\n              \"      <td>16.263456</td>\\n\",\n              \"      <td>475.000000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>498.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>475.000000</td>\\n\",\n              \"      <td>452.000000</td>\\n\",\n              \"      <td>498.000000</td>\\n\",\n              \"      <td>0.200004</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>8.500000</td>\\n\",\n              \"      <td>16.263456</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>80</td>\\n\",\n              \"      <td>0.000600</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      
<td>341.000000</td>\\n\",\n              \"      <td>296.000000</td>\\n\",\n              \"      <td>386.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>341.000000</td>\\n\",\n              \"      <td>296.000000</td>\\n\",\n              \"      <td>386.000000</td>\\n\",\n              \"      <td>0.606937</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-2.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>81</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>456.500000</td>\\n\",\n              \"      <td>428.000000</td>\\n\",\n              \"      <td>485.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>456.500000</td>\\n\",\n              \"      <td>428.000000</td>\\n\",\n              \"      <td>485.000000</td>\\n\",\n              \"      <td>0.235978</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>82</td>\\n\",\n              \"      <td>0.000800</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>407.000000</td>\\n\",\n              \"      <td>326.000000</td>\\n\",\n              \"      <td>488.000000</td>\\n\",\n              \"      
<td>0.000000</td>\\n\",\n              \"      <td>407.000000</td>\\n\",\n              \"      <td>326.000000</td>\\n\",\n              \"      <td>488.000000</td>\\n\",\n              \"      <td>0.825952</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>83</td>\\n\",\n              \"      <td>0.000200</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>557.500000</td>\\n\",\n              \"      <td>529.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>529.000000</td>\\n\",\n              \"      <td>529.000000</td>\\n\",\n              \"      <td>529.000000</td>\\n\",\n              \"      <td>0.239547</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"    <tr>\\n\",\n              \"      <td>84</td>\\n\",\n              \"      <td>0.001600</td>\\n\",\n              \"      <td>-1.000000</td>\\n\",\n              \"      <td>2.828427</td>\\n\",\n              \"      <td>368.500000</td>\\n\",\n              \"      <td>151.000000</td>\\n\",\n              \"      <td>586.000000</td>\\n\",\n              \"      <td>0.500000</td>\\n\",\n              \"      <td>151.000000</td>\\n\",\n              \"      <td>151.000000</td>\\n\",\n              \"      
<td>151.000000</td>\\n\",\n              \"      <td>1.608883</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>2.121320</td>\\n\",\n              \"      <td>0.000000</td>\\n\",\n              \"      <td>1.414214</td>\\n\",\n              \"      <td>-0.500000</td>\\n\",\n              \"      <td>0.707107</td>\\n\",\n              \"    </tr>\\n\",\n              \"  </tbody>\\n\",\n              \"</table><p>\"\n            ],\n            \"text/plain\": [\n              \"<IPython.core.display.HTML object>\"\n            ]\n          },\n          \"metadata\": {},\n          \"output_type\": \"display_data\"\n        },\n        {\n          \"metadata\": {\n            \"tags\": null\n          },\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Steps = 1 State = failed\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Helper: simulate a move, return new board and score\\n\",\n            \"    def simulate(board, dir):\\n\",\n            \"        n = len(board)\\n\",\n            \"        new = [[0]*n for _ in range(n)]\\n\",\n            \"        score = 0\\n\",\n            \"        for i in range(n):\\n\",\n            \"            # extract line\\n\",\n            \"            if dir == 'A':\\n\",\n            \"                line = [board[i][j] for j in range(n)]\\n\",\n            \"                rev = False\\n\",\n            \"            elif dir == 'D':\\n\",\n            \"                line = [board[i][j] for j in range(n-1, -1, -1)]\\n\",\n            \"                rev = True\\n\",\n            \"            elif dir == 'W':\\n\",\n            \"                line = [board[j][i] for j in range(n)]\\n\",\n            \"                rev = False\\n\",\n            \"            else:  # 'S'\\n\",\n            \"                line = [board[j][i] for j in range(n-1, -1, -1)]\\n\",\n            \"                
rev = True\\n\",\n            \"            # compress and merge\\n\",\n            \"            new_line = [x for x in line if x != 0]\\n\",\n            \"            merged = []\\n\",\n            \"            j = 0\\n\",\n            \"            while j < len(new_line):\\n\",\n            \"                if j + 1 < len(new_line) and new_line[j] == new_line[j+1]:\\n\",\n            \"                    merged.append(new_line[j]*2)\\n\",\n            \"                    score += new_line[j]*2\\n\",\n            \"                    j += 2\\n\",\n            \"                else:\\n\",\n            \"                    merged.append(new_line[j])\\n\",\n            \"                    j += 1\\n\",\n            \"            # fill with zeros\\n\",\n            \"            merged += [0]*(n-len(merged))\\n\",\n            \"            # place back\\n\",\n            \"            if rev:\\n\",\n            \"                merged = merged[::-1]\\n\",\n            \"            if dir in ('A','D'):\\n\",\n            \"                for j in range(n):\\n\",\n            \"                    new[i][j] = merged[j]\\n\",\n            \"            else:\\n\",\n            \"                for j in range(n):\\n\",\n            \"                    new[j][i] = merged[j]\\n\",\n            \"        return new, score\\n\",\n            \"\\n\",\n            \"    best, best_dir = 0, None\\n\",\n            \"    for dir in ('W','A','S','D'):\\n\",\n            \"        _, score = simulate(board, dir)\\n\",\n            \"        if score > best:\\n\",\n            \"            best, best_dir = score, dir\\n\",\n            \"    return best_dir  # returns one of 'W','A','S','D'\\n\",\n            \"┌───┬───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            
\"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\u001b[38;5;239m  .\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┴───┘\\n\",\n            \"Unsloth: Will smartly offload gradients to save VRAM!\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helpers\\n\",\n            \"    def move(b, d):\\n\",\n            \"        n = len(b)\\n\",\n            \"        def compress(row):\\n\",\n            \"            new = [x for x in row if x!=0]\\n\",\n            \"            for i in range(len(new)-1):\\n\",\n            \"                if new[i]==new[i+1]:\\n\",\n            \"                    new[i]*=2; new[i+1]=0\\n\",\n            \"            return [x for x in new if x!=0]+[0]*(n-len(new))\\n\",\n            \"        res=[[0]*n for _ in range(n)]\\n\",\n            \"     
   if d==\\\"W\\\":\\n\",\n            \"            for j in range(n):\\n\",\n            \"                col=[b[i][j] for i in range(n)]\\n\",\n            \"                col=compress(col)\\n\",\n            \"                for i in range(n):\\n\",\n            \"                    res[i][j]=col[i]\\n\",\n            \"        elif d==\\\"S\\\":\\n\",\n            \"            for j in range(n):\\n\",\n            \"                col=[b[i][j] for i in range(n)][::-1]\\n\",\n            \"                col=compress(col)\\n\",\n            \"                col=col[::-1]\\n\",\n            \"                for i in range(n):\\n\",\n            \"                    res[i][j]=col[i]\\n\",\n            \"        elif d==\\\"A\\\":\\n\",\n            \"            for i in range(n):\\n\",\n            \"                row=compress(b[i])\\n\",\n            \"                res[i]=row\\n\",\n            \"        elif d==\\\"D\\\":\\n\",\n            \"            for i in range(n):\\n\",\n            \"                row=compress(b[i][::-1])\\n\",\n            \"                row=row[::-1]\\n\",\n            \"                res[i]=row\\n\",\n            \"        return res\\n\",\n            \"\\n\",\n            \"    def score(b):\\n\",\n            \"        return sum(sum(row) for row in b)\\n\",\n            \"\\n\",\n            \"    moves=\\\"WASD\\\"\\n\",\n            \"    best=None; best_val=-1\\n\",\n            \"    for m in moves:\\n\",\n            \"        nb=move(board, m)\\n\",\n            \"        val=score(nb)\\n\",\n            \"        if val>best_val and any(nb[i][j]!=board[i][j] for i in range(len(nb)) for j in range(len(nb[0]))):\\n\",\n            \"            best_val=val; best=m\\n\",\n            \"    return best if best else \\\"W\\\"\\n\",\n            \"Exception = list index out of range\\n\",\n            \"Timeout\\n\",\n            \"Steps = 475 State = failed\\n\",\n            \"def 
strategy(board):\\n\",\n            \"    def move_possible(board, direction):\\n\",\n            \"        rows, cols = len(board), len(board[0])\\n\",\n            \"        if direction == 'W':\\n\",\n            \"            for j in range(cols):\\n\",\n            \"                for i in range(1, rows):\\n\",\n            \"                    if board[i][j] != 0:\\n\",\n            \"                        for k in range(i-1, -1, -1):\\n\",\n            \"                            if board[k][j] == 0 or board[k][j] == board[i][j]:\\n\",\n            \"                                return True\\n\",\n            \"                            if board[k][j] != 0:\\n\",\n            \"                                break\\n\",\n            \"        elif direction == 'S':\\n\",\n            \"            for j in range(cols):\\n\",\n            \"                for i in range(rows-2, -1, -1):\\n\",\n            \"                    if board[i][j] != 0:\\n\",\n            \"                        for k in range(i+1, rows):\\n\",\n            \"                            if board[k][j] == 0 or board[k][j] == board[i][j]:\\n\",\n            \"                                return True\\n\",\n            \"                            if board[k][j] != 0:\\n\",\n            \"                                break\\n\",\n            \"        elif direction == 'A':\\n\",\n            \"            for i in range(rows):\\n\",\n            \"                for j in range(1, cols):\\n\",\n            \"                    if board[i][j] != 0:\\n\",\n            \"                        for k in range(j-1, -1, -1):\\n\",\n            \"                            if board[i][k] == 0 or board[i][k] == board[i][j]:\\n\",\n            \"                                return True\\n\",\n            \"                            if board[i][k] != 0:\\n\",\n            \"                                break\\n\",\n            \"        elif direction == 
'D':\\n\",\n            \"            for i in range(rows):\\n\",\n            \"                for j in range(cols-2, -1, -1):\\n\",\n            \"                    if board[i][j] != 0:\\n\",\n            \"                        for k in range(j+1, cols):\\n\",\n            \"                            if board[i][k] == 0 or board[i][k] == board[i][j]:\\n\",\n            \"                                return True\\n\",\n            \"                            if board[i][k] != 0:\\n\",\n            \"                                break\\n\",\n            \"        return False\\n\",\n            \"\\n\",\n            \"    # Prefer moves that allow a merge as they increase score\\n\",\n            \"    for d in ('W', 'S', 'A', 'D'):\\n\",\n            \"        if move_possible(board, d):\\n\",\n            \"            return d\\n\",\n            \"    # If no merges are possible, pick any direction that moves tiles\\n\",\n            \"    for d in ('W', 'S', 'A', 'D'):\\n\",\n            \"        if any(board[i][j] != 0 for i in range(len(board)) for j in range(len(board[0]))):\\n\",\n            \"            return d\\n\",\n            \"    return 'W'\\n\",\n            \"┌───┬───┬───┬───┬───┬───┐\\n\",\n            \"│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;47m 16\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;46m 32\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;154m128\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;46m 32\\u001b[0m│\\u001b[38;5;118m 64\\u001b[0m│\\u001b[38;5;226m256\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;118m 64\\u001b[0m│\\u001b[38;5;46m 32\\u001b[0m│\\n\",\n            
\"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;154m128\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;47m 16\\u001b[0m│\\u001b[38;5;118m 64\\u001b[0m│\\u001b[38;5;46m 32\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;47m 16\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\n\",\n            \"├───┼───┼───┼───┼───┼───┤\\n\",\n            \"│\\u001b[38;5;118m 64\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\u001b[38;5;49m  8\\u001b[0m│\\u001b[38;5;51m  4\\u001b[0m│\\u001b[38;5;45m  2\\u001b[0m│\\n\",\n            \"└───┴───┴───┴───┴───┴───┘\\n\",\n            \"Exception = '>' not supported between instances of 'tuple' and 'float'\\n\",\n            \"def strategy(board):\\n\",\n            \"    import random, copy\\n\",\n            \"\\n\",\n            \"    def rotate(b):\\n\",\n            \"        return [[b[3-j][i] for j in range(4)] for i in range(4)]\\n\",\n            \"\\n\",\n            \"    def compress(b):\\n\",\n            \"        new = []\\n\",\n            \"        for row in b:\\n\",\n            \"            new_row = [x for x in row if x != 0]\\n\",\n            \"            new_row += [0]*(4-len(new_row))\\n\",\n            \"            new.append(new_row)\\n\",\n            \"        return new\\n\",\n            \"\\n\",\n            \"    def merge(b):\\n\",\n            \"        for row in b:\\n\",\n            \"            for i in range(3):\\n\",\n            \"                if row[i]==row[i+1] and row[i]!=0:\\n\",\n            \"                    row[i]*=2\\n\",\n            \"                    row[i+1]=0\\n\",\n            \"\\n\",\n            \"    def move(b, dir):\\n\",\n            \"        if dir==\\\"W\\\":\\n\",\n            \"            return 
merge(rotate(compress(rotate(b))))\\n\",\n            \"        if dir==\\\"S\\\":\\n\",\n            \"            return rotate(merge(compress(rotate(b))))\\n\",\n            \"        if dir==\\\"A\\\":\\n\",\n            \"            return merge(compress(b))\\n\",\n            \"        if dir==\\\"D\\\":\\n\",\n            \"            return rotate(merge(compress(rotate(b))))  # actually reverse\\n\",\n            \"\\n\",\n            \"    best_score=0\\n\",\n            \"    best_move=None\\n\",\n            \"    for move_dir in \\\"WASD\\\":\\n\",\n            \"        new_board=move(copy.deepcopy(board), move_dir)\\n\",\n            \"        score=sum(sum(row) for row in new_board)\\n\",\n            \"        if score>best_score:\\n\",\n            \"            best_score=score\\n\",\n            \"            best_move=move_dir\\n\",\n            \"    return best_move\\n\",\n            \"Exception = 'NoneType' object is not iterable\\n\",\n            \"Exception = name 'n' is not defined\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"None\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Prioritize merges, then favor left/up moves\\n\",\n            \"    rows, cols = len(board), len(board[0]) if board else 0\\n\",\n            \"\\n\",\n            \"    # Helper to check if a move is possible\\n\",\n            \"    def can_move(direction):\\n\",\n            \"        if direction == 'W':\\n\",\n            \"            for c in range(cols):\\n\",\n            \"                for r in range(rows-1):\\n\",\n            \"                    if board[r][c] == 0 or board[r][c] == board[r+1][c]:\\n\",\n            \"                        return True\\n\",\n            \"        elif direction == 'A':\\n\",\n            \"            for r in range(rows):\\n\",\n            \"                for c in range(cols-1):\\n\",\n            \"                    if 
board[r][c] == 0 or board[r][c] == board[r][c+1]:\\n\",\n            \"                        return True\\n\",\n            \"        elif direction == 'S':\\n\",\n            \"            for c in range(cols):\\n\",\n            \"                for r in range(rows-1,0,-1):\\n\",\n            \"                    if board[r][c] == 0 or board[r][c] == board[r-1][c]:\\n\",\n            \"                        return True\\n\",\n            \"        elif direction == 'D':\\n\",\n            \"            for r in range(rows):\\n\",\n            \"                for c in range(cols-1,0,-1):\\n\",\n            \"                    if board[r][c] == 0 or board[r][c] == board[r][c-1]:\\n\",\n            \"                        return True\\n\",\n            \"        return False\\n\",\n            \"\\n\",\n            \"    # Generate all moves\\n\",\n            \"    moves = []\\n\",\n            \"    for d in ['W', 'A', 'S', 'D']:\\n\",\n            \"        if can_move(d):\\n\",\n            \"            moves.append(d)\\n\",\n            \"\\n\",\n            \"    # If multiple moves, pick one that maximizes the sum of merges\\n\",\n            \"    if not moves:\\n\",\n            \"        return 'W'  # fallback\\n\",\n            \"    # Simple heuristic: prefer first move that allows a merge\\n\",\n            \"    return moves[0]\\n\",\n            \"Timeout\\n\",\n            \"Steps = 1512 State = failed\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helper to check possible merge in a row or column\\n\",\n            \"    def can_merge(lst):\\n\",\n            \"        for i in range(len(lst)-1):\\n\",\n            \"            if lst[i] > 0 and lst[i] == lst[i+1]:\\n\",\n            \"                return True\\n\",\n            \"        return False\\n\",\n            \"\\n\",\n            \"    # try to move in a direction that creates a merge\\n\",\n            \"    for dir, delta in [(\\\"W\\\", (-1,0)), 
(\\\"A\\\", (0,-1)), (\\\"S\\\", (1,0)), (\\\"D\\\", (0,1))]:\\n\",\n            \"        merged = False\\n\",\n            \"        for i in range(len(board)):\\n\",\n            \"            for j in range(len(board[0])):\\n\",\n            \"                if board[i][j] > 0:\\n\",\n            \"                    ni, nj = i + delta[0], j + delta[1]\\n\",\n            \"                    if 0 <= ni < len(board) and 0 <= nj < len(board[0]):\\n\",\n            \"                        if board[ni][nj] == 0:\\n\",\n            \"                            return dir\\n\",\n            \"                        if board[ni][nj] == board[i][j]:\\n\",\n            \"                            return dir\\n\",\n            \"    # fallback: move down\\n\",\n            \"    return \\\"S\\\"\\n\",\n            \"┌────┬────┬────┬────┬────┬────┐\\n\",\n            \"│\\u001b[38;5;214m 512\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;226m 256\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;154m 128\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;208m1024\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;226m 256\\u001b[0m│\\u001b[38;5;154m 128\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;226m 256\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;154m 128\\u001b[0m│\\u001b[38;5;118m  
64\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\n\",\n            \"└────┴────┴────┴────┴────┴────┘\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Simple greedy: choose direction that keeps tiles sorted in ascending order left-bottom\\n\",\n            \"    best = \\\" \\\"\\n\",\n            \"    best_val = -1\\n\",\n            \"    for d in \\\"WASD\\\":\\n\",\n            \"        # simulate move\\n\",\n            \"        b = [row[:] for row in board]\\n\",\n            \"        # merge function\\n\",\n            \"        def merge(row):\\n\",\n            \"            new = [x for x in row if x != 0]\\n\",\n            \"            res = []\\n\",\n            \"            i = 0\\n\",\n            \"            while i < len(new):\\n\",\n            \"                if i+1 < len(new) and new[i] == new[i+1]:\\n\",\n            \"                    res.append(new[i]*2)\\n\",\n            \"                    i += 2\\n\",\n            \"                else:\\n\",\n            \"                    res.append(new[i])\\n\",\n            \"                    i += 1\\n\",\n            \"            return res + [0]*(len(row)-len(res))\\n\",\n            \"        moved = False\\n\",\n            \"        if d == \\\"W\\\":\\n\",\n            \"            for col in range(4):\\n\",\n            \"                col_vals = [board[r][col] for r in range(4)]\\n\",\n            \"                merged = merge(col_vals)\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    b[r][col] = 
merged[r]\\n\",\n            \"        elif d == \\\"S\\\":\\n\",\n            \"            for col in range(4):\\n\",\n            \"                col_vals = [board[r][col] for r in range(4)][::-1]\\n\",\n            \"                merged = merge(col_vals)[::-1]\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    b[r][col] = merged[r]\\n\",\n            \"        elif d == \\\"A\\\":\\n\",\n            \"            for r in range(4):\\n\",\n            \"                row_vals = board[r]\\n\",\n            \"                merged = merge(row_vals)\\n\",\n            \"                b[r] = merged\\n\",\n            \"        elif d == \\\"D\\\":\\n\",\n            \"            for r in range(4):\\n\",\n            \"                row_vals = board[r][::-1]\\n\",\n            \"                merged = merge(row_vals)[::-1]\\n\",\n            \"                b[r] = merged\\n\",\n            \"        score = sum(filter(None, [x for row in b for x in row]))\\n\",\n            \"        if score > best_val:\\n\",\n            \"            best_val = score\\n\",\n            \"            best = d\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = 'str' object is not callable\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helper to rotate board\\n\",\n            \"    def rotate(b): return [list(col)[::-1] for col in zip(*b)]\\n\",\n            \"    # helper to move up\\n\",\n            \"    def move_up(b):\\n\",\n            \"        n=len(b)\\n\",\n            \"        new=[[] for _ in range(n)]\\n\",\n            \"        for j in range(n):\\n\",\n            \"            col=[b[i][j] for i in range(n) if b[i][j]!=0]\\n\",\n            \"            merged=[]\\n\",\n            \"            i=0\\n\",\n            \"            while i< len(col):\\n\",\n            \"     
           if i+1<len(col) and col[i]==col[i+1]:\\n\",\n            \"                    merged.append(col[i]*2)\\n\",\n            \"                    i+=2\\n\",\n            \"                else:\\n\",\n            \"                    merged.append(col[i])\\n\",\n            \"                    i+=1\\n\",\n            \"            new_col=[0]*(n-len(merged))+merged\\n\",\n            \"            for i in range(n):\\n\",\n            \"                new[i][j]=new_col[i]\\n\",\n            \"        return new\\n\",\n            \"    best=None\\n\",\n            \"    best_val=-1\\n\",\n            \"    for dir in [\\\"W\\\",\\\"A\\\",\\\"S\\\",\\\"D\\\"]:\\n\",\n            \"        # move board in given direction\\n\",\n            \"        b=[row[:] for row in board]\\n\",\n            \"        if dir==\\\"W\\\": b=move_up(b)\\n\",\n            \"        elif dir==\\\"S\\\": b=[list(row[::-1]) for row in move_up([row[::-1] for row in b])]\\n\",\n            \"        elif dir==\\\"A\\\": b=[list(row[::-1]) for row in move_up([row[::-1] for row in b])]\\n\",\n            \"        elif dir==\\\"D\\\": b=[list(row[::-1]) for row in b]\\n\",\n            \"        # evaluate\\n\",\n            \"        val=max(max(row) for row in b)\\n\",\n            \"        if val>best_val:\\n\",\n            \"            best_val=val; best=dir\\n\",\n            \"    return best\\n\",\n            \"Exception = list assignment index out of range\\n\",\n            \"Timeout\\n\",\n            \"Exception = list index out of range\\n\",\n            \"def strategy(board):\\n\",\n            \"    import copy\\n\",\n            \"    moves = \\\"WASD\\\"\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for m in moves:\\n\",\n            \"        b = copy.deepcopy(board)\\n\",\n            \"        if m==\\\"W\\\":\\n\",\n            \"            for c in range(len(b)):\\n\",\n            \"            
    merged = []\\n\",\n            \"                for r in range(len(b)):\\n\",\n            \"                    val = b[r][c]\\n\",\n            \"                    if val!=0:\\n\",\n            \"                        merged.append(val)\\n\",\n            \"                i=0\\n\",\n            \"                while i+1<len(merged):\\n\",\n            \"                    if merged[i]==merged[i+1]:\\n\",\n            \"                        merged[i]*=2\\n\",\n            \"                        merged.pop(i+1)\\n\",\n            \"                    i+=1\\n\",\n            \"                merged+= [0]*(len(b)-len(merged))\\n\",\n            \"                for r in range(len(b)):\\n\",\n            \"                    b[r][c]=merged[r]\\n\",\n            \"        elif m==\\\"S\\\":\\n\",\n            \"            for c in range(len(b)):\\n\",\n            \"                merged = []\\n\",\n            \"                for r in reversed(range(len(b))):\\n\",\n            \"                    val = b[r][c]\\n\",\n            \"                    if val!=0:\\n\",\n            \"                        merged.append(val)\\n\",\n            \"                i=0\\n\",\n            \"                while i+1<len(merged):\\n\",\n            \"                    if merged[i]==merged[i+1]:\\n\",\n            \"                        merged[i]*=2\\n\",\n            \"                        merged.pop(i+1)\\n\",\n            \"                    i+=1\\n\",\n            \"                merged+= [0]*(len(b)-len(merged))\\n\",\n            \"                for r in range(len(b)):\\n\",\n            \"                    b[r][c]=merged[len(b)-1-r]\\n\",\n            \"        elif m==\\\"A\\\":\\n\",\n            \"            for r in range(len(b)):\\n\",\n            \"                row = b[r]\\n\",\n            \"                merged = [v for v in row if v!=0]\\n\",\n            \"                i=0\\n\",\n            \"           
     while i+1<len(merged):\\n\",\n            \"                    if merged[i]==merged[i+1]:\\n\",\n            \"                        merged[i]*=2\\n\",\n            \"                        merged.pop(i+1)\\n\",\n            \"                    i+=1\\n\",\n            \"                merged+= [0]*(len(b)-len(merged))\\n\",\n            \"                b[r]=merged\\n\",\n            \"        elif m==\\\"D\\\":\\n\",\n            \"            for r in range(len(b)):\\n\",\n            \"                row = list(reversed(b[r]))\\n\",\n            \"                merged = [v for v in row if v!=0]\\n\",\n            \"                i=0\\n\",\n            \"                while i+1<len(merged):\\n\",\n            \"                    if merged[i]==merged[i+1]:\\n\",\n            \"                        merged[i]*=2\\n\",\n            \"                        merged.pop(i+1)\\n\",\n            \"                    i+=1\\n\",\n            \"                merged+= [0]*(len(b)-len(merged))\\n\",\n            \"                b[r]=list(reversed(merged))\\n\",\n            \"        score=sum(sum(row) for row in b)\\n\",\n            \"        if score>best_score:\\n\",\n            \"            best_score=score; best=m\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = unsupported operand type(s) for -: 'range' and 'int'\\n\",\n            \"def strategy(board):\\n\",\n            \"    # board is a 4x4 list of ints, 0 for empty\\n\",\n            \"    # Simple greedy: move that merges most tiles\\n\",\n            \"    moves = {}\\n\",\n            \"    dirs = {\\\"W\\\": (-1,0), \\\"A\\\": (0,-1), \\\"S\\\": (1,0), \\\"D\\\": (0,1)}\\n\",\n            \"    for d, (dr,dc) in dirs.items():\\n\",\n            \"        # simulate move\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        merged = 0\\n\",\n            \"     
   for i in range(4):\\n\",\n            \"            for j in range(4):\\n\",\n            \"                if new_board[i][j]==0: continue\\n\",\n            \"                ni, nj = i+dr, j+dc\\n\",\n            \"                while 0<=ni<4 and 0<=nj<4 and new_board[ni][nj]==0:\\n\",\n            \"                    ni+=dr; nj+=dc\\n\",\n            \"                if 0<=ni<4 and 0<=nj<4 and new_board[ni][nj]==new_board[i][j]:\\n\",\n            \"                    merged+=1\\n\",\n            \"        moves[d]=merged\\n\",\n            \"    # choose direction with most merges, default W\\n\",\n            \"    best = max(moves, key=moves.get)\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = list index out of range\\n\",\n            \"def strategy(board):\\n\",\n            \"    moves = \\\"WASD\\\"\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for m in moves:\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        if m == \\\"W\\\":\\n\",\n            \"            new_board = _move_up(new_board)\\n\",\n            \"        elif m == \\\"A\\\":\\n\",\n            \"            new_board = _move_left(new_board)\\n\",\n            \"        elif m == \\\"S\\\":\\n\",\n            \"            new_board = _move_down(new_board)\\n\",\n            \"        else:  # \\\"D\\\"\\n\",\n            \"            new_board = _move_right(new_board)\\n\",\n            \"        score = sum(sum(row) for row in new_board)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score, best = score, m\\n\",\n            \"    return best\\n\",\n            \"\\n\",\n            \"def _compress(line):\\n\",\n            \"    nonzero = [x for x in line if x]\\n\",\n            \"    res = []\\n\",\n            \"    i = 
0\\n\",\n            \"    while i < len(nonzero):\\n\",\n            \"        if i + 1 < len(nonzero) and nonzero[i] == nonzero[i+1]:\\n\",\n            \"            res.append(nonzero[i]*2)\\n\",\n            \"            i += 2\\n\",\n            \"        else:\\n\",\n            \"            res.append(nonzero[i])\\n\",\n            \"            i += 1\\n\",\n            \"    return res + [0]*(len(line)-len(res))\\n\",\n            \"\\n\",\n            \"def _move_up(b):\\n\",\n            \"    n = len(b)\\n\",\n            \"    res = [[0]*n for _ in range(n)]\\n\",\n            \"    for j in range(n):\\n\",\n            \"        col = [b[i][j] for i in range(n)]\\n\",\n            \"        col = _compress(col)\\n\",\n            \"        for i in range(n):\\n\",\n            \"            res[i][j] = col[i]\\n\",\n            \"    return res\\n\",\n            \"\\n\",\n            \"def _move_down(b):\\n\",\n            \"    n = len(b)\\n\",\n            \"    res = [[0]*n for _ in range(n)]\\n\",\n            \"    for j in range(n):\\n\",\n            \"        col = [b[i][j] for i in range(n)][::-1]\\n\",\n            \"        col = _compress(col)\\n\",\n            \"        for i in range(n):\\n\",\n            \"            res[n-1-i][j] = col[i]\\n\",\n            \"    return res\\n\",\n            \"\\n\",\n            \"def _move_left(b):\\n\",\n            \"    n = len(b)\\n\",\n            \"    res = [[0]*n for _ in range(n)]\\n\",\n            \"    for i in range(n):\\n\",\n            \"        row = _compress(b[i])\\n\",\n            \"        res[i] = row\\n\",\n            \"    return res\\n\",\n            \"\\n\",\n            \"def _move_right(b):\\n\",\n            \"    n = len(b)\\n\",\n            \"    res = [[0]*n for _ in range(n)]\\n\",\n            \"    for i in range(n):\\n\",\n            \"        row = _compress(b[i][::-1])[::-1]\\n\",\n            \"        res[i] = row\\n\",\n            \"    return 
res\\n\",\n            \"Exception = 'int' object is not subscriptable\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helper to apply a move and return new board\\n\",\n            \"    def move(b, dir):\\n\",\n            \"        n = len(b)\\n\",\n            \"        res = [[0]*n for _ in range(n)]\\n\",\n            \"        for x in range(n):\\n\",\n            \"            line = []\\n\",\n            \"            for y in range(n):\\n\",\n            \"                i,j = (y,x) if dir==\\\"D\\\" else (x,y)\\n\",\n            \"                if dir==\\\"A\\\": i=j\\n\",\n            \"            # skip for brevity\\n\",\n            \"\\n\",\n            \"    # simplified heuristic: choose direction that increases sum of merged tiles\\n\",\n            \"    best, best_sum = None, -1\\n\",\n            \"    dirs = \\\"WASD\\\"\\n\",\n            \"    for d in dirs:\\n\",\n            \"        new = move(board, d)\\n\",\n            \"        merged = sum(c for r in new for c in r) - sum(c for r in board for c in r)\\n\",\n            \"        if merged > best_sum:\\n\",\n            \"            best_sum, best = merged, d\\n\",\n            \"    return best\\n\",\n            \"Exception = 'NoneType' object is not iterable\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    import math\\n\",\n            \"    def score(b):\\n\",\n            \"        empty = sum(1 for r in b for v in r if v==0)\\n\",\n            \"        mx = max(max(row) for row in b)\\n\",\n            \"        return empty*10 + mx\\n\",\n            \"    best=None; best_score=-math.inf\\n\",\n            \"    for move in \\\"WASD\\\":\\n\",\n            \"        new=board.copy()\\n\",\n            \"        # simulate simple move logic\\n\",\n            \"        if move==\\\"W\\\":\\n\",\n            \"          
  for col in range(4):\\n\",\n            \"                col_vals=[r[col] for r in new if r[col]!=0]\\n\",\n            \"                for i,row in enumerate(col_vals):\\n\",\n            \"                    new[i][col]=col_vals[i]\\n\",\n            \"                for i in range(i+1,4):\\n\",\n            \"                    new[i][col]=0\\n\",\n            \"        elif move==\\\"S\\\":\\n\",\n            \"            for col in range(4):\\n\",\n            \"                col_vals=[r[col] for r in new if r[col]!=0]\\n\",\n            \"                for i,row in enumerate(reversed(col_vals)):\\n\",\n            \"                    new[3-i][col]=col_vals[i]\\n\",\n            \"                for i in range(3-i+1,4):\\n\",\n            \"                    new[i][col]=0\\n\",\n            \"        elif move==\\\"A\\\":\\n\",\n            \"            for row in range(4):\\n\",\n            \"                row_vals=[v for v in new[row] if v!=0]\\n\",\n            \"                for i,v in enumerate(row_vals):\\n\",\n            \"                    new[row][i]=row_vals[i]\\n\",\n            \"                for i in range(i+1,4):\\n\",\n            \"                    new[row][i]=0\\n\",\n            \"        elif move==\\\"D\\\":\\n\",\n            \"            for row in range(4):\\n\",\n            \"                row_vals=[v for v in new[row] if v!=0]\\n\",\n            \"                for i,v in enumerate(reversed(row_vals)):\\n\",\n            \"                    new[row][3-i]=row_vals[i]\\n\",\n            \"                for i in range(3-i+1,4):\\n\",\n            \"                    new[row][i]=0\\n\",\n            \"        sc=score(new)\\n\",\n            \"        if sc>best_score:\\n\",\n            \"            best_score=sc; best=move\\n\",\n            \"    return best\\n\",\n            \"Exception = cannot access local variable 'i' where it is not associated with a value\\n\",\n            
\"Timeout\\n\",\n            \"Exception = name 'merge' is not defined\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # 4x4 board\\n\",\n            \"    moves = 'W A S D'.split()\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for m in moves:\\n\",\n            \"        b = [row[:] for row in board]  # copy\\n\",\n            \"        for i in range(4):\\n\",\n            \"            line = b[i] if m in 'AD' else [row[i] for row in b]\\n\",\n            \"            merged = []\\n\",\n            \"            skip = False\\n\",\n            \"            for j, v in enumerate(line):\\n\",\n            \"                if v == 0: continue\\n\",\n            \"                if skip:\\n\",\n            \"                    skip = False\\n\",\n            \"                    continue\\n\",\n            \"                if j + 1 < len(line) and line[j+1] == v:\\n\",\n            \"                    merged.append(v*2)\\n\",\n            \"                    skip = True\\n\",\n            \"                else:\\n\",\n            \"                    merged.append(v)\\n\",\n            \"            while len(merged) < 4:\\n\",\n            \"                merged.append(0)\\n\",\n            \"            if m in 'AD':\\n\",\n            \"                for k in range(4): b[i][k] = merged[k]\\n\",\n            \"            else:\\n\",\n            \"                for k in range(4): b[k][i] = merged[k]\\n\",\n            \"        score = sum(sum(row) for row in b)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score = score\\n\",\n            \"            best = m\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # board is a list of 
lists representing a 4x4 grid.\\n\",\n            \"    # possible moves\\n\",\n            \"    moves = ['W', 'A', 'S', 'D']\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    \\n\",\n            \"    def score(b):\\n\",\n            \"        s = 0\\n\",\n            \"        for row in b:\\n\",\n            \"            for v in row:\\n\",\n            \"                s += v\\n\",\n            \"        return s\\n\",\n            \"    \\n\",\n            \"    for m in moves:\\n\",\n            \"        nb = [row[:] for row in board]\\n\",\n            \"        # simulate move m (very naive: just return new board if any merge)\\n\",\n            \"        merged = False\\n\",\n            \"        for i in range(4):\\n\",\n            \"            for j in range(4):\\n\",\n            \"                if nb[i][j] == 0: continue\\n\",\n            \"                for di, dj in ( (-1,0),(1,0),(0,-1),(0,1) ):\\n\",\n            \"                    ni, nj = i+di, j+dj\\n\",\n            \"                    if 0<=ni<4 and 0<=nj<4 and nb[ni][nj]==nb[i][j]:\\n\",\n            \"                        nb[ni][nj] += nb[i][j]\\n\",\n            \"                        nb[i][j] = 0\\n\",\n            \"                        merged = True\\n\",\n            \"        if merged:\\n\",\n            \"            sc = score(nb)\\n\",\n            \"            if sc > best_score:\\n\",\n            \"                best_score, best = sc, m\\n\",\n            \"    return best if best is not None else moves[0]\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = cannot access local variable 'val' where it is not associated with a value\\n\",\n            \"None\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = not enough values to unpack (expected 2, got 1)\\n\",\n            \"def 
strategy(board):\\n\",\n            \"    # evaluate a move by the total sum after the move\\n\",\n            \"    def sim(b, m):\\n\",\n            \"        n = len(b)\\n\",\n            \"        b = [row[:] for row in b]\\n\",\n            \"        moved = False\\n\",\n            \"        if m == 'W':\\n\",\n            \"            for j in range(n):\\n\",\n            \"                col = [b[i][j] for i in range(n)]\\n\",\n            \"                col += [0]*(n-len(col))\\n\",\n            \"                newcol = []\\n\",\n            \"                i = 0\\n\",\n            \"                while i < n:\\n\",\n            \"                    if col[i] == 0:\\n\",\n            \"                        i += 1\\n\",\n            \"                        continue\\n\",\n            \"                    val = col[i]\\n\",\n            \"                    i += 1\\n\",\n            \"                    while i < n and col[i] == 0: i += 1\\n\",\n            \"                    if i < n and col[i] == val:\\n\",\n            \"                        val *= 2\\n\",\n            \"                        i += 1\\n\",\n            \"                    newcol.append(val)\\n\",\n            \"                for i in range(n):\\n\",\n            \"                    b[i][j] = newcol[i] if i < len(newcol) else 0\\n\",\n            \"            moved = True\\n\",\n            \"        # other moves omitted for brevity  \\n\",\n            \"        return b if moved else None\\n\",\n            \"\\n\",\n            \"    best, best_val = None, -1\\n\",\n            \"    for m in \\\"WASD\\\":\\n\",\n            \"        r = sim(board, m)\\n\",\n            \"        if r:\\n\",\n            \"            val = sum(sum(row) for row in r)\\n\",\n            \"            if val > best_val:\\n\",\n            \"                best_val, best = val, m\\n\",\n            \"    return best if best else \\\"W\\\"\\n\",\n            
\"Timeout\\n\",\n            \"Exception = list index out of range\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"Timeout\\n\",\n            \"Exception = strategy.<locals>.rotate() takes 1 positional argument but 2 were given\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helper to simulate a move\\n\",\n            \"    def move(b, direction):\\n\",\n            \"        size = len(b)\\n\",\n            \"        new = [[0]*size for _ in range(size)]\\n\",\n            \"        for i in range(size):\\n\",\n            \"            if direction in ('A','D'):\\n\",\n            \"                line = b[i] if direction=='D' else b[i][::-1]\\n\",\n            \"            else:\\n\",\n            \"                line = [b[j][i] for j in range(size)]\\n\",\n            \"                if direction=='S': line = line[::-1]\\n\",\n            \"            merged = []\\n\",\n            \"            skip = False\\n\",\n            \"            for val in line:\\n\",\n            \"                if val==0: continue\\n\",\n            \"                if merged and merged[-1]==val and not skip:\\n\",\n            \"                    merged[-1] += val\\n\",\n            \"                    skip = True\\n\",\n            \"                else:\\n\",\n            \"                    merged.append(val)\\n\",\n            \"                    skip = False\\n\",\n            \"            for j,v in enumerate(merged):\\n\",\n            \"                new[i if direction=='A' else size-1-i][j if direction=='A' else size-1-j] = v\\n\",\n            \"        return new\\n\",\n            \"\\n\",\n            \"    # evaluate each move\\n\",\n            \"    best = None\\n\",\n            \"    best_val = -1\\n\",\n            \"    for dirc in 'WASD':\\n\",\n            \"        new_board = move(board, dirc)\\n\",\n            
\"        val = sum(sum(row) for row in new_board)\\n\",\n            \"        if val > best_val:\\n\",\n            \"            best_val = val\\n\",\n            \"            best = dirc\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"None\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"None\\n\",\n            \"Exception = unsupported operand type(s) for -: 'list' and 'int'\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Simple heuristic: move up unless a merge is possible in another direction\\n\",\n            \"    # Check if any pair can merge horizontally or vertically\\n\",\n            \"    for i in range(4):\\n\",\n            \"        for j in range(3):\\n\",\n            \"            if board[i][j] == board[i][j+1]:\\n\",\n            \"                return \\\"A\\\"  # left\\n\",\n            \"    for i in range(3):\\n\",\n            \"        for j in range(4):\\n\",\n            \"            if board[i][j] == board[i+1][j]:\\n\",\n            \"                return \\\"W\\\"  # up\\n\",\n            \"    return \\\"D\\\"  # fallback\\n\",\n            \"Timeout\\n\",\n            \"Exception = list index out of range\\n\",\n            \"def strategy(board):\\n\",\n            \"    def score_for(move):\\n\",\n            \"        B = [row[:] for row in board]\\n\",\n            \"        def slide(row):\\n\",\n            \"            new = [x for x in row if x != 0]\\n\",\n            \"            res = []\\n\",\n            \"            skip = False\\n\",\n            \"            for i, x in enumerate(new):\\n\",\n            \"                if skip:\\n\",\n            \"                    skip = False\\n\",\n            \"                    continue\\n\",\n      
      \"                if i+1 < len(new) and new[i] == new[i+1]:\\n\",\n            \"                    res.append(x*2)\\n\",\n            \"                    skip = True\\n\",\n            \"                else:\\n\",\n            \"                    res.append(x)\\n\",\n            \"            return res + [0]*(len(row)-len(res))\\n\",\n            \"        if move=='W':\\n\",\n            \"            for i in range(len(B)):\\n\",\n            \"                B[i] = slide(B[i])\\n\",\n            \"        elif move=='S':\\n\",\n            \"            B = B[::-1]\\n\",\n            \"            for i in range(len(B)):\\n\",\n            \"                B[i] = slide(B[i])\\n\",\n            \"            B = B[::-1]\\n\",\n            \"        elif move=='A':\\n\",\n            \"            for row in B:\\n\",\n            \"                row[:] = slide(row)\\n\",\n            \"        elif move=='D':\\n\",\n            \"            for row in B:\\n\",\n            \"                row[:] = slide(row[::-1])[::-1]\\n\",\n            \"        empty = sum(cell==0 for r in B for cell in r)\\n\",\n            \"        return empty\\n\",\n            \"    best=None\\n\",\n            \"    for m in 'WASD':\\n\",\n            \"        if score_for(m)>best[1] if best else -1:\\n\",\n            \"            best=(m,score_for(m))\\n\",\n            \"    return best[0]\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = list assignment index out of range\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    '''\\n\",\n            \"    Returns the best next move for a 2048 game using a very small heuristic.\\n\",\n            \"    The heuristic looks at the free spaces after the move and chooses the\\n\",\n            \"    direction that tends to leave the most empty tiles.\\n\",\n            \"    '''\\n\",\n            \"    
from functools import lru_cache\\n\",\n            \"\\n\",\n            \"    # Flatten the board for easier hashing\\n\",\n            \"    flatten = tuple(tuple(row) for row in board)\\n\",\n            \"\\n\",\n            \"    # Helper: simulate a move\\n\",\n            \"    def move(state, direction):\\n\",\n            \"        size = len(state)\\n\",\n            \"        new_state = []\\n\",\n            \"        for row in state:\\n\",\n            \"            merged = []\\n\",\n            \"            for d in row:\\n\",\n            \"                if d != 0:\\n\",\n            \"                    merged.append(d)\\n\",\n            \"\\n\",\n            \"            if direction in ('A', 'D'):  # horizontal move\\n\",\n            \"                merged = merged[::-1] if direction == 'D' else merged\\n\",\n            \"                i = 0\\n\",\n            \"                while i < len(merged) - 1:\\n\",\n            \"                    if merged[i] == merged[i + 1]:\\n\",\n            \"                        merged[i] *= 2\\n\",\n            \"                        merged.pop(i + 1)\\n\",\n            \"                    i += 1\\n\",\n            \"                merged += [0] * (size - len(merged))\\n\",\n            \"                if direction == 'D':\\n\",\n            \"                    merged = merged[::-1]\\n\",\n            \"                new_state.append(tuple(merged))\\n\",\n            \"            else:  # vertical move\\n\",\n            \"                new_state.append(tuple(merged))\\n\",\n            \"        # For vertical moves, reconstruct column-wise\\n\",\n            \"        if direction in ('W', 'S'):\\n\",\n            \"            transposed = list(zip(*new_state))\\n\",\n            \"            new_state = []\\n\",\n            \"            for col in transposed:\\n\",\n            \"                merged = []\\n\",\n            \"                for d in col:\\n\",\n       
     \"                    if d != 0:\\n\",\n            \"                        merged.append(d)\\n\",\n            \"                merged = merged[::-1] if direction == 'S' else merged\\n\",\n            \"                i = 0\\n\",\n            \"                while i < len(merged) - 1:\\n\",\n            \"                    if merged[i] == merged[i + 1]:\\n\",\n            \"                        merged[i] *= 2\\n\",\n            \"                        merged.pop(i + 1)\\n\",\n            \"                    i += 1\\n\",\n            \"                merged += [0] * (size - len(merged))\\n\",\n            \"                if direction == 'S':\\n\",\n            \"                    merged = merged[::-1]\\n\",\n            \"                new_state.append(tuple(merged))\\n\",\n            \"            new_state = [tuple(row) for row in zip(*new_state)]\\n\",\n            \"        return tuple(tuple(row) for row in new_state)\\n\",\n            \"\\n\",\n            \"    # Count empty tiles\\n\",\n            \"    def empty_count(state):\\n\",\n            \"        return sum(1 for row in state for cell in row if cell == 0)\\n\",\n            \"\\n\",\n            \"    best_move = None\\n\",\n            \"    best_empty = -1\\n\",\n            \"    for move in ['W', 'A', 'S', 'D']:\\n\",\n            \"        new_board = move(flatten, move)\\n\",\n            \"        e = empty_count(new_board)\\n\",\n            \"        if e > best_empty:\\n\",\n            \"            best_empty = e\\n\",\n            \"            best_move = move\\n\",\n            \"    return best_move\\n\",\n            \"Exception = 'str' object is not callable\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    import copy\\n\",\n            \"    # Helper to apply a move and return new board\\n\",\n            \"    def move(board, dir):\\n\",\n        
    \"        size = len(board)\\n\",\n            \"        def compress(line):\\n\",\n            \"            new = [x for x in line if x>0]\\n\",\n            \"            res = []\\n\",\n            \"            i = 0\\n\",\n            \"            while i < len(new):\\n\",\n            \"                if i+1 < len(new) and new[i]==new[i+1]:\\n\",\n            \"                    res.append(new[i]*2)\\n\",\n            \"                    i += 2\\n\",\n            \"                else:\\n\",\n            \"                    res.append(new[i])\\n\",\n            \"                    i += 1\\n\",\n            \"            res += [0]*(size-len(res))\\n\",\n            \"            return res\\n\",\n            \"        if dir=='W':\\n\",\n            \"            new = [compress(col) for col in zip(*board)]\\n\",\n            \"            return [list(row) for row in zip(*new)]\\n\",\n            \"        if dir=='A':\\n\",\n            \"            return [compress(row) for row in board]\\n\",\n            \"        if dir=='S':\\n\",\n            \"            rev = [list(reversed(row)) for row in board]\\n\",\n            \"            new = [compress(row) for row in rev]\\n\",\n            \"            return [list(reversed(row)) for row in new]\\n\",\n            \"        if dir=='D':\\n\",\n            \"            rev = [list(reversed(row)) for row in board]\\n\",\n            \"            new = [compress(row) for row in rev]\\n\",\n            \"            return [list(row) for row in new]\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for d in ['W','A','S','D']:\\n\",\n            \"        newboard = move(board, d)\\n\",\n            \"        # score: sum of all tiles (higher better)\\n\",\n            \"        score = sum(sum(row) for row in newboard)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score, best = score, 
d\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # helper to simulate a move and compute score\\n\",\n            \"    def simulate(move):\\n\",\n            \"        n = len(board)\\n\",\n            \"        new_board = [[0]*n for _ in range(n)]\\n\",\n            \"        for i in range(n):\\n\",\n            \"            line = board[i] if move in \\\"WB\\\" else [row[i] for row in board]\\n\",\n            \"            if move in \\\"DS\\\":  # reverse for down/right\\n\",\n            \"                line = line[::-1]\\n\",\n            \"            merged = []\\n\",\n            \"            skip = False\\n\",\n            \"            for v in line:\\n\",\n            \"                if v == 0: continue\\n\",\n            \"                if merged and merged[-1][0] == v and not skip:\\n\",\n            \"                    merged[-1] = (merged[-1][0]*2, merged[-1][1]+1)\\n\",\n            \"                    skip = True\\n\",\n            \"                else:\\n\",\n            \"                    merged.append((v, 0))\\n\",\n            \"                    skip = False\\n\",\n            \"            merged += [(0,0)]*(n-len(merged))\\n\",\n            \"            for idx, (v, _) in enumerate(merged):\\n\",\n            \"                new_board[i if move in \\\"WD\\\" else idx][idx if move in \\\"WD\\\" else i] = v\\n\",\n            \"        return sum(sum(row) for row in new_board)\\n\",\n            \"\\n\",\n            \"    best_move = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for m in \\\"WASD\\\":\\n\",\n            \"        try:\\n\",\n            \"            score = simulate(m)\\n\",\n            \"            if score > best_score:\\n\",\n            \"                best_score = 
score\\n\",\n            \"                best_move = m\\n\",\n            \"        except:\\n\",\n            \"            continue\\n\",\n            \"    return best_move or \\\"W\\\"\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = name 'n' is not defined\\n\",\n            \"def strategy(board):\\n\",\n            \"    import copy\\n\",\n            \"    moves = {'W': (-1,0), 'A': (0,-1), 'S': (1,0), 'D': (0,1)}\\n\",\n            \"    def move(b, dir):\\n\",\n            \"        size = len(b)\\n\",\n            \"        mx, my = moves[dir]\\n\",\n            \"        new = [[0]*size for _ in range(size)]\\n\",\n            \"        for r in range(size):\\n\",\n            \"            line = []\\n\",\n            \"            nr = r + mx\\n\",\n            \"            for c in range(size):\\n\",\n            \"                nc = c + my\\n\",\n            \"                if 0 <= nr < size and 0 <= nc < size:\\n\",\n            \"                    line.append(b[nr][nc])\\n\",\n            \"            # compress\\n\",\n            \"            res=[]\\n\",\n            \"            i=0\\n\",\n            \"            while i < len(line):\\n\",\n            \"                if i+1<len(line) and line[i]==line[i+1]:\\n\",\n            \"                    res.append(line[i]*2); i+=2\\n\",\n            \"                else:\\n\",\n            \"                    res.append(line[i]); i+=1\\n\",\n            \"            for i,val in enumerate(res):\\n\",\n            \"                nr = (r+mx*i if mx else r)\\n\",\n            \"                nc = (c+my*i if my else c)\\n\",\n            \"                new[nr][nc]=val\\n\",\n            \"        return new\\n\",\n            \"    def score(b):\\n\",\n            \"        s=0\\n\",\n            \"        for r in range(len(b)):\\n\",\n            \"            for c in range(len(b)):\\n\",\n            \"                if 
b[r][c]>0:\\n\",\n            \"                    s+=b[r][c]\\n\",\n            \"        return s\\n\",\n            \"    best=None\\n\",\n            \"    best_score=-1\\n\",\n            \"    for m in moves:\\n\",\n            \"        nb=move(board,m)\\n\",\n            \"        s=score(nb)\\n\",\n            \"        if s>best_score:\\n\",\n            \"            best_score=s; best=m\\n\",\n            \"    return best\\n\",\n            \"Exception = list index out of range\\n\",\n            \"Exception = 'NoneType' object is not subscriptable\\n\",\n            \"Exception = name 'col_index' is not defined\\n\",\n            \"def strategy(board):\\n\",\n            \"    import copy\\n\",\n            \"    moves = \\\"WASD\\\"\\n\",\n            \"    best, best_move = -1, \\\"W\\\"\\n\",\n            \"    for m in moves:\\n\",\n            \"        b = copy.deepcopy(board)\\n\",\n            \"        if m == \\\"W\\\":\\n\",\n            \"            for i in range(3,-1,-1):\\n\",\n            \"                for j in range(4):\\n\",\n            \"                    if b[i][j] and b[i-1][j] and b[i][j]==b[i-1][j]:\\n\",\n            \"                        b[i-1][j]*=2; b[i][j]=0\\n\",\n            \"        elif m == \\\"S\\\":\\n\",\n            \"            for i in range(4):\\n\",\n            \"                for j in range(4):\\n\",\n            \"                    if i<3 and b[i][j] and b[i+1][j] and b[i][j]==b[i+1][j]:\\n\",\n            \"                        b[i+1][j]*=2; b[i][j]=0\\n\",\n            \"        elif m == \\\"A\\\":\\n\",\n            \"            for i in range(4):\\n\",\n            \"                for j in range(4):\\n\",\n            \"                    if j<3 and b[i][j] and b[i][j+1] and b[i][j]==b[i][j+1]:\\n\",\n            \"                        b[i][j+1]*=2; b[i][j]=0\\n\",\n            \"        elif m == \\\"D\\\":\\n\",\n            \"            for i in range(4):\\n\",\n         
   \"                for j in range(3,-1,-1):\\n\",\n            \"                    if j>0 and b[i][j] and b[i][j-1] and b[i][j]==b[i][j-1]:\\n\",\n            \"                        b[i][j-1]*=2; b[i][j]=0\\n\",\n            \"        score = sum(sum(row) for row in b)\\n\",\n            \"        if score > best:\\n\",\n            \"            best, best_move = score, m\\n\",\n            \"    return best_move\\n\",\n            \"Timeout\\n\",\n            \"Steps = 1825 State = failed\\n\",\n            \"def strategy(board):\\n\",\n            \"    size = len(board)\\n\",\n            \"    # Helper to compute score of moves\\n\",\n            \"    def score_move(d):\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        moved = False\\n\",\n            \"        if d == \\\"W\\\":\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col = [new_board[i][j] for i in range(size)]\\n\",\n            \"                merged = merge(col)\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new_board[i][j] = merged[i]\\n\",\n            \"                if merged != col:\\n\",\n            \"                    moved = True\\n\",\n            \"        elif d == \\\"S\\\":\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col = [new_board[i][j] for i in range(size)][::-1]\\n\",\n            \"                merged = merge(col)[::-1]\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new_board[i][j] = merged[i]\\n\",\n            \"                if merged[::-1] != col:\\n\",\n            \"                    moved = True\\n\",\n            \"        elif d == \\\"A\\\":\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row = new_board[i][:]\\n\",\n            \"                merged = 
merge(row)\\n\",\n            \"                new_board[i] = merged\\n\",\n            \"                if merged != row:\\n\",\n            \"                    moved = True\\n\",\n            \"        elif d == \\\"D\\\":\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row = new_board[i][::-1]\\n\",\n            \"                merged = merge(row)[::-1]\\n\",\n            \"                new_board[i] = merged\\n\",\n            \"                if merged[::-1] != row:\\n\",\n            \"                    moved = True\\n\",\n            \"        return moved, new_board\\n\",\n            \"\\n\",\n            \"    def merge(line):\\n\",\n            \"        filtered = [x for x in line if x != 0]\\n\",\n            \"        merged = []\\n\",\n            \"        i = 0\\n\",\n            \"        while i < len(filtered):\\n\",\n            \"            if i+1 < len(filtered) and filtered[i] == filtered[i+1]:\\n\",\n            \"                merged.append(filtered[i]*2)\\n\",\n            \"                i += 2\\n\",\n            \"            else:\\n\",\n            \"                merged.append(filtered[i])\\n\",\n            \"                i += 1\\n\",\n            \"        merged += [0]*(size-len(merged))\\n\",\n            \"        return merged\\n\",\n            \"\\n\",\n            \"    # Evaluate each direction\\n\",\n            \"    best_score = -1\\n\",\n            \"    best_dir = \\\"W\\\"\\n\",\n            \"    for d in \\\"WASD\\\":\\n\",\n            \"        moved, new_board = score_move(d)\\n\",\n            \"        if not moved:\\n\",\n            \"            continue\\n\",\n            \"        # simple heuristic: sum of all tiles\\n\",\n            \"        score = sum(sum(row) for row in new_board)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score = score\\n\",\n            \"            best_dir = 
d\\n\",\n            \"    return best_dir\\n\",\n            \"┌────┬────┬────┬────┬────┬────┐\\n\",\n            \"│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;208m1024\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;208m1024\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;214m 512\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;214m 512\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;154m 128\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;226m 256\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\n\",\n            \"└────┴────┴────┴────┴────┴────┘\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # Evaluate score for each move and pick the one with maximal tile value\\n\",\n            \"    dirs = {\\\"W\\\": (-1,0), \\\"A\\\": (0,-1), \\\"S\\\": (1,0), \\\"D\\\": (0,1)}\\n\",\n            \"    best = None\\n\",\n   
         \"    best_score = -1\\n\",\n            \"    for d, (dx, dy) in dirs.items():\\n\",\n            \"        new_board = [[0]*4 for _ in range(4)]\\n\",\n            \"        moved = False\\n\",\n            \"        for i in range(4):\\n\",\n            \"            for j in range(4):\\n\",\n            \"                ni, nj = i+dx, j+dy\\n\",\n            \"                if 0 <= ni < 4 and 0 <= nj < 4:\\n\",\n            \"                    new_board[ni][nj] = board[i][j]\\n\",\n            \"                    if new_board[ni][nj] != board[i][j]:\\n\",\n            \"                        moved = True\\n\",\n            \"        if not moved:\\n\",\n            \"            continue\\n\",\n            \"        score = sum([sum(row) for row in new_board])\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score = score\\n\",\n            \"            best = d\\n\",\n            \"    return best if best is not None else \\\"W\\\"\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Exception = 'list_reverseiterator' object is not subscriptable\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    def score_row(row, dir):\\n\",\n            \"        if dir == 'L':\\n\",\n            \"            row = row[::-1]\\n\",\n            \"        merged = []\\n\",\n            \"        skip = False\\n\",\n            \"        for val in row:\\n\",\n            \"            if val == 0: continue\\n\",\n            \"            if skip:\\n\",\n            \"                skip = False\\n\",\n            \"                continue\\n\",\n            \"            if merged and merged[-1] == val:\\n\",\n            \"                merged[-1] *= 2\\n\",\n            \"                skip = True\\n\",\n            \"            else:\\n\",\n            \"                merged.append(val)\\n\",\n            \"        merged += 
[0]*(len(row)-len(merged))\\n\",\n            \"        if dir == 'L':\\n\",\n            \"            merged = merged[::-1]\\n\",\n            \"        return merged\\n\",\n            \"\\n\",\n            \"    def move(board, action):\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        if action in 'L':\\n\",\n            \"            for r in new_board:\\n\",\n            \"                new_row = score_row(r, 'L')\\n\",\n            \"                for i, val in enumerate(new_row):\\n\",\n            \"                    r[i] = val\\n\",\n            \"        elif action in 'R':\\n\",\n            \"            for r in new_board:\\n\",\n            \"                new_row = score_row(r, 'R')\\n\",\n            \"                for i, val in enumerate(new_row):\\n\",\n            \"                    r[i] = val\\n\",\n            \"        elif action in 'U':\\n\",\n            \"            for c in range(4):\\n\",\n            \"                col = [new_board[r][c] for r in range(4)]\\n\",\n            \"                new_col = score_row(col, 'L')\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    new_board[r][c] = new_col[r]\\n\",\n            \"        elif action in 'D':\\n\",\n            \"            for c in range(4):\\n\",\n            \"                col = [new_board[r][c] for r in range(4)]\\n\",\n            \"                new_col = score_row(col, 'R')\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    new_board[r][c] = new_col[r]\\n\",\n            \"        return new_board\\n\",\n            \"\\n\",\n            \"    def empty(board):\\n\",\n            \"        return [(r, c) for r in range(4) for c in range(4) if board[r][c] == 0]\\n\",\n            \"\\n\",\n            \"    actions = 'WASD'\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n       
     \"    for a in actions:\\n\",\n            \"        new = move(board, a)\\n\",\n            \"        empties = len(empty(new))\\n\",\n            \"        merged = sum(1 for r in new for val in r if val >0)\\n\",\n            \"        score = empties + merged\\n\",\n            \"        if score>best_score:\\n\",\n            \"            best_score = score\\n\",\n            \"            best = a\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # choose a move that keeps more tiles unchanged\\n\",\n            \"    moves = ['W','A','S','D']\\n\",\n            \"    best = moves[0]; best_score = -1\\n\",\n            \"    for m in moves:\\n\",\n            \"        new = board_state_after(board, m)\\n\",\n            \"        if new == board:\\n\",\n            \"            continue\\n\",\n            \"        score = score_board(new)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score = score; best = m\\n\",\n            \"    return best\\n\",\n            \"def board_state_after(board, move):\\n\",\n            \"    # simulate move on a copy of the board\\n\",\n            \"    from copy import deepcopy\\n\",\n            \"    b = deepcopy(board)\\n\",\n            \"    n = len(b)\\n\",\n            \"    # simple implementation of move logic\\n\",\n            \"    def compress(line):\\n\",\n            \"        new = [x for x in line if x!=0]\\n\",\n            \"        res = []\\n\",\n            \"        i=0\\n\",\n            \"        while i < len(new):\\n\",\n            \"            if i+1<len(new) and new[i]==new[i+1]:\\n\",\n            \"                res.append(new[i]*2); i+=2\\n\",\n            \"            else:\\n\",\n            \"                res.append(new[i]); i+=1\\n\",\n            \"        res 
+= [0]*(n-len(res))\\n\",\n            \"        return res\\n\",\n            \"    if move=='W':\\n\",\n            \"        for j in range(n):\\n\",\n            \"            col=[b[i][j] for i in range(n)]\\n\",\n            \"            col=compress(col)\\n\",\n            \"            for i in range(n): b[i][j]=col[i]\\n\",\n            \"    elif move=='S':\\n\",\n            \"        for j in range(n):\\n\",\n            \"            col=[b[i][j] for i in range(n)][::-1]\\n\",\n            \"            col=compress(col)[::-1]\\n\",\n            \"            for i in range(n): b[i][j]=col[i]\\n\",\n            \"    elif move=='A':\\n\",\n            \"        for i in range(n):\\n\",\n            \"            row=compress(b[i])\\n\",\n            \"            b[i]=row\\n\",\n            \"    elif move=='D':\\n\",\n            \"        for i in range(n):\\n\",\n            \"            row=compress(b[i][::-1])[::-1]\\n\",\n            \"            b[i]=row\\n\",\n            \"    return b\\n\",\n            \"def score_board(board):\\n\",\n            \"    # higher score for more homogeneous board\\n\",\n            \"    total=0\\n\",\n            \"    for row in board:\\n\",\n            \"        for v in row:\\n\",\n            \"            total+=v\\n\",\n            \"    return total\\n\",\n            \"Exception = list assignment index out of range\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    # simulate four possible moves and choose the one\\n\",\n            \"    def move(board, dir):\\n\",\n            \"        size = len(board)\\n\",\n            \"        def compress(line):\\n\",\n            \"            filtered = [x for x in line if x != 0]\\n\",\n            \"            merged = []\\n\",\n            \"            skip = False\\n\",\n            \"            for i in range(len(filtered)):\\n\",\n            \"                if skip: 
skip = False; continue\\n\",\n            \"                if i+1 < len(filtered) and filtered[i] == filtered[i+1]:\\n\",\n            \"                    merged.append(filtered[i]*2)\\n\",\n            \"                    skip = True\\n\",\n            \"                else:\\n\",\n            \"                    merged.append(filtered[i])\\n\",\n            \"            merged += [0]*(size-len(merged))\\n\",\n            \"            return merged\\n\",\n            \"        new = [[0]*size for _ in range(size)]\\n\",\n            \"        if dir == 'W':\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col = [board[i][j] for i in range(size)]\\n\",\n            \"                merged = compress(col)\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new[i][j] = merged[i]\\n\",\n            \"        elif dir == 'S':\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col = [board[i][j] for i in range(size)][::-1]\\n\",\n            \"                merged = compress(col)[::-1]\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new[i][j] = merged[i]\\n\",\n            \"        elif dir == 'A':\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row = board[i]\\n\",\n            \"                merged = compress(row)\\n\",\n            \"                new[i] = merged\\n\",\n            \"        elif dir == 'D':\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row = board[i][::-1]\\n\",\n            \"                merged = compress(row)[::-1]\\n\",\n            \"                new[i] = merged\\n\",\n            \"        return new\\n\",\n            \"\\n\",\n            \"    best = None\\n\",\n            \"    best_score = -1\\n\",\n            \"    for dir in 
('W','A','S','D'):\\n\",\n            \"        new = move(board, dir)\\n\",\n            \"        score = sum(sum(row) for row in new)\\n\",\n            \"        if score > best_score:\\n\",\n            \"            best_score = score\\n\",\n            \"            best = dir\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    def move(board, dir):\\n\",\n            \"        import copy\\n\",\n            \"        n=len(board)\\n\",\n            \"        new=[row[:] for row in board]\\n\",\n            \"        if dir=='W':\\n\",\n            \"            for j in range(n):\\n\",\n            \"                col=[new[i][j] for i in range(n)]\\n\",\n            \"                newcol=compress(col)\\n\",\n            \"                for i in range(n): new[i][j]=newcol[i]\\n\",\n            \"        elif dir=='S':\\n\",\n            \"            for j in range(n):\\n\",\n            \"                col=[new[i][j] for i in range(n)][::-1]\\n\",\n            \"                newcol=compress(col)[::-1]\\n\",\n            \"                for i in range(n): new[i][j]=newcol[i]\\n\",\n            \"        elif dir=='A':\\n\",\n            \"            for i in range(n):\\n\",\n            \"                new[i]=compress(new[i])\\n\",\n            \"        elif dir=='D':\\n\",\n            \"            for i in range(n):\\n\",\n            \"                new[i]=compress(new[i])[::-1][::-1]\\n\",\n            \"        return new\\n\",\n            \"    def compress(line):\\n\",\n            \"        filtered=[v for v in line if v>0]\\n\",\n            \"        res=[]\\n\",\n            \"        i=0\\n\",\n            \"        while i<len(filtered):\\n\",\n            \"            if i+1<len(filtered) and filtered[i]==filtered[i+1]:\\n\",\n            \"                
res.append(filtered[i]*2); i+=2\\n\",\n            \"            else:\\n\",\n            \"                res.append(filtered[i]); i+=1\\n\",\n            \"        return res+[0]*(len(line)-len(res))\\n\",\n            \"    best=None\\n\",\n            \"    best_score=-1\\n\",\n            \"    for d in \\\"WASD\\\":\\n\",\n            \"        nb=move(board,d)\\n\",\n            \"        score=sum(sum(row) for row in nb)\\n\",\n            \"        if score>best_score:\\n\",\n            \"            best_score=score; best=d\\n\",\n            \"    return best\\n\",\n            \"Timeout\\n\",\n            \"Steps = 1264 State = success\\n\",\n            \"def strategy(board):\\n\",\n            \"    # board is a 4x4 list of lists\\n\",\n            \"    import random\\n\",\n            \"    \\n\",\n            \"    # Directions with priority: diagonal corners\\n\",\n            \"    dirs = ['W', 'A', 'S', 'D']\\n\",\n            \"    for d in dirs:\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        if d == 'W':\\n\",\n            \"            for j in range(4):\\n\",\n            \"                merged = False\\n\",\n            \"                for i in range(1, 4):\\n\",\n            \"                    if new_board[i][j] == new_board[i-1][j] and not merged:\\n\",\n            \"                        new_board[i-1][j] += new_board[i][j]\\n\",\n            \"                        new_board[i][j] = 0\\n\",\n            \"                        merged = True\\n\",\n            \"        elif d == 'S':\\n\",\n            \"            for j in range(4):\\n\",\n            \"                merged = False\\n\",\n            \"                for i in range(2, -1, -1):\\n\",\n            \"                    if new_board[i][j] == new_board[i+1][j] and not merged:\\n\",\n            \"                        new_board[i+1][j] += new_board[i][j]\\n\",\n            \"                        
new_board[i][j] = 0\\n\",\n            \"                        merged = True\\n\",\n            \"        elif d == 'A':\\n\",\n            \"            for i in range(4):\\n\",\n            \"                merged = False\\n\",\n            \"                for j in range(1, 4):\\n\",\n            \"                    if new_board[i][j] == new_board[i][j-1] and not merged:\\n\",\n            \"                        new_board[i][j-1] += new_board[i][j]\\n\",\n            \"                        new_board[i][j] = 0\\n\",\n            \"                        merged = True\\n\",\n            \"        elif d == 'D':\\n\",\n            \"            for i in range(4):\\n\",\n            \"                merged = False\\n\",\n            \"                for j in range(2, -1, -1):\\n\",\n            \"                    if new_board[i][j] == new_board[i][j+1] and not merged:\\n\",\n            \"                        new_board[i][j+1] += new_board[i][j]\\n\",\n            \"                        new_board[i][j] = 0\\n\",\n            \"                        merged = True\\n\",\n            \"        # measure score: number of non-zero tiles\\n\",\n            \"        score = sum(1 for r in new_board for v in r if v != 0)\\n\",\n            \"        # choose first direction that reduces empty tiles\\n\",\n            \"        if score > sum(1 for r in board for v in r if v != 0):\\n\",\n            \"            return d\\n\",\n            \"    return random.choice(dirs)\\n\",\n            \"┌────┬────┬────┬────┬────┬────┐\\n\",\n            \"│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   
.\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;118m  64\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;49m   8\\u001b[0m│\\u001b[38;5;239m   .\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;46m  32\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\u001b[38;5;226m 256\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\n\",\n            \"├────┼────┼────┼────┼────┼────┤\\n\",\n            \"│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;47m  16\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;196m2048\\u001b[0m│\\u001b[38;5;51m   4\\u001b[0m│\\u001b[38;5;45m   2\\u001b[0m│\\n\",\n            \"└────┴────┴────┴────┴────┴────┘\\n\",\n            \"Exception = '>' not supported between instances of 'int' and 'str'\\n\",\n            \"Exception = cannot pickle 'generator' object\\n\",\n            \"Timeout\\n\",\n            \"def strategy(board):\\n\",\n            \"    def move(board, direction):\\n\",\n            \"        size = len(board)\\n\",\n            \"        def compress(line):\\n\",\n            \"            new = [x for x in line if x>0]\\n\",\n            \"            merged = []\\n\",\n            \"            i=0\\n\",\n            \"            while i < len(new):\\n\",\n            \"                if i+1 < len(new) and new[i]==new[i+1]:\\n\",\n            \"                    merged.append(new[i]*2)\\n\",\n            \"                    i+=2\\n\",\n            \"                else:\\n\",\n    
        \"                    merged.append(new[i])\\n\",\n            \"                    i+=1\\n\",\n            \"            return merged+[0]*(size-len(merged))\\n\",\n            \"        new_board=[[0]*size for _ in range(size)]\\n\",\n            \"        if direction=='W':\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col=[board[i][j] for i in range(size)]\\n\",\n            \"                col=compress(col)\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new_board[i][j]=col[i]\\n\",\n            \"        elif direction=='S':\\n\",\n            \"            for j in range(size):\\n\",\n            \"                col=[board[i][j] for i in range(size)][::-1]\\n\",\n            \"                col=compress(col)[::-1]\\n\",\n            \"                for i in range(size):\\n\",\n            \"                    new_board[i][j]=col[i]\\n\",\n            \"        elif direction=='A':\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row=compress(board[i])\\n\",\n            \"                new_board[i]=row\\n\",\n            \"        elif direction=='D':\\n\",\n            \"            for i in range(size):\\n\",\n            \"                row=compress(board[i][::-1])[::-1]\\n\",\n            \"                new_board[i]=row\\n\",\n            \"        return new_board\\n\",\n            \"\\n\",\n            \"    def score(b):\\n\",\n            \"        return sum(sum(1 for x in row if x>0) for row in b)\\n\",\n            \"\\n\",\n            \"    best=None\\n\",\n            \"    bestScore=-1\\n\",\n            \"    for d in \\\"WASD\\\":\\n\",\n            \"        nb=move(board,d)\\n\",\n            \"        s=score(nb)\\n\",\n            \"        if s>bestScore:\\n\",\n            \"            bestScore=s\\n\",\n            \"            best=d\\n\",\n            \"    return 
best\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"Timeout\\n\",\n            \"None\\n\",\n            \"Timeout\\n\",\n            \"Exception = list assignment index out of range\\n\",\n            \"Timeout\\n\",\n            \"None\\n\",\n            \"Timeout\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"trainer.train()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tlaUdxC_VHpz\"\n      },\n      \"source\": [\n        \"<a name=\\\"Inference\\\"></a>\\n\",\n        \"# Inference\\n\",\n        \"Now let's try the model we just trained!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 31,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"8BZZHOKiF9Ct\",\n        \"outputId\": \"4989f8d9-d024-462e-c732-b7734676791a\"\n      },\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\\n\",\n            \"Knowledge cutoff: 2024-06\\n\",\n            \"Current date: 2025-10-06\\n\",\n            \"\\n\",\n            \"Reasoning: low\\n\",\n            \"\\n\",\n            \"# Valid channels: analysis, commentary, final. 
Channel must be included for every message.\\n\",\n            \"Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>Create a new short 2048 strategy using only native Python code.\\n\",\n            \"You are given a list of list of numbers for the current board state.\\n\",\n            \"Output one action for \\\"W\\\", \\\"A\\\", \\\"S\\\", \\\"D\\\" on what is the optimal next step.\\n\",\n            \"Output your new short function in backticks using the format below:\\n\",\n            \"```python\\n\",\n            \"def strategy(board):\\n\",\n            \"    return \\\"W\\\" # Example\\n\",\n            \"```\\n\",\n            \"All helper functions should be inside def strategy. Only output the short function `strategy`.<|end|><|start|>assistant<|channel|>analysis<|message|>We need a short strategy that picks best move using basic heuristics. Use simple scoring: prefer moving to increase score: choose move that gives largest tile sum after simulation. 
Write quick function.<|end|><|start|>assistant<|channel|>final<|message|>```python\\n\",\n            \"def strategy(board):\\n\",\n            \"    import itertools\\n\",\n            \"    dirs = [\\\"W\\\",\\\"A\\\",\\\"S\\\",\\\"D\\\"]\\n\",\n            \"    best_dir = None\\n\",\n            \"    best_val = -1\\n\",\n            \"    for d in dirs:\\n\",\n            \"        new_board = [row[:] for row in board]\\n\",\n            \"        moved = False\\n\",\n            \"        if d == \\\"W\\\":\\n\",\n            \"            for col in range(4):\\n\",\n            \"                temp = [new_board[r][col] for r in range(4) if new_board[r][col]]\\n\",\n            \"                merged = []\\n\",\n            \"                i=0\\n\",\n            \"                while i<len(temp):\\n\",\n            \"                    if i+1<len(temp) and temp[i]==temp[i+1]:\\n\",\n            \"                        merged.append(temp[i]*2); i+=2\\n\",\n            \"                    else:\\n\",\n            \"                        merged.append(temp[i]); i+=1\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    new_board[r][col]=merged[r] if r<len(merged) else 0\\n\",\n            \"        elif d == \\\"S\\\":\\n\",\n            \"            for col in range(4):\\n\",\n            \"                temp = [new_board[r][col] for r in range(4) if new_board[r][col]]\\n\",\n            \"                merged = []\\n\",\n            \"                i=len(temp)-1\\n\",\n            \"                while i>=0:\\n\",\n            \"                    if i-1>=0 and temp[i]==temp[i-1]:\\n\",\n            \"                        merged.insert(0,temp[i]*2); i-=2\\n\",\n            \"                    else:\\n\",\n            \"                        merged.insert(0,temp[i]); i-=1\\n\",\n            \"                for r in range(4):\\n\",\n            \"                    
new_board[r][col]=merged[r] if r<len(merged) else 0\\n\",\n            \"        elif d == \\\"A\\\":\\n\",\n            \"            for r in range(4):\\n\",\n            \"                temp = [new_board[r][c] for c in range(4) if new_board[r][c]]\\n\",\n            \"                merged = []\\n\",\n            \"                i=0\\n\",\n            \"                while i<len(temp):\\n\",\n            \"                    if i+1<len(temp) and temp[i]==temp[i+1]:\\n\",\n            \"                        merged.append(temp[i]*2); i+=2\\n\",\n            \"                    else:\\n\",\n            \"                        merged.append(temp[i]); i+=1\\n\",\n            \"                for c in range(4):\\n\",\n            \"                    new_board[r][c]=merged[c] if c<len(merged) else 0\\n\",\n            \"        elif d == \\\"D\\\":\\n\",\n            \"            for r in range(4):\\n\",\n            \"                temp = [new_board[r][c] for c in range(4) if new_board[r][c]]\\n\",\n            \"                merged = []\\n\",\n            \"                i=len(temp)-1\\n\",\n            \"                while i>=0:\\n\",\n            \"                    if i-1>=0 and temp[i]==temp[i-1]:\\n\",\n            \"                        merged.insert(0,temp[i]*2); i-=2\\n\",\n            \"                    else:\\n\",\n            \"                        merged.insert(0,temp[i]); i-=1\\n\",\n            \"                for c in range(4):\\n\",\n            \"                    new_board[r][c]=merged[c] if c<len(merged) else 0\\n\",\n            \"        val = sum(map(sum,new_board))\\n\",\n            \"        if val>best_val:\\n\",\n            \"            best_val, best_dir = val, d\\n\",\n            \"    return best_dir\\n\",\n            \"```<|return|>\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"text = tokenizer.apply_chat_template(\\n\",\n        \"    [{\\\"role\\\": \\\"user\\\", 
\\\"content\\\": prompt}],\\n\",\n        \"    tokenize = False,\\n\",\n        \"    add_generation_prompt = True,\\n\",\n        \"    reasoning_effort = \\\"low\\\",\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"from transformers import TextStreamer\\n\",\n        \"_ = model.generate(\\n\",\n        \"    **tokenizer(text, return_tensors = \\\"pt\\\").to(\\\"cuda\\\"),\\n\",\n        \"    temperature = 1.0,\\n\",\n        \"    max_new_tokens = 1024,\\n\",\n        \"    streamer = TextStreamer(tokenizer, skip_prompt = False),\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-NUEmHFSYNTp\"\n      },\n      \"source\": [\n        \"<a name=\\\"Save\\\"></a>\\n\",\n        \"### Saving to float16 or `MXFP4`\\n\",\n        \"\\n\",\n        \"We also support saving to `float16` directly. Select `merged_16bit` for float16 or `mxfp4` for MXFP4 (OpenAI's gpt-oss native precision). We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! 
You can create a personal access token at https://huggingface.co/settings/tokens.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 32,\n      \"metadata\": {\n        \"id\": \"NjXGTkp7YNtB\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Merge and push to hub in mxfp4 4bit format\\n\",\n        \"if False:\\n\",\n        \"    model.save_pretrained_merged(\\\"finetuned_model\\\", tokenizer, save_method = \\\"mxfp4\\\")\\n\",\n        \"if False:\\n\",\n        \"    model.push_to_hub_merged(\\\"repo_id/repo_name\\\", tokenizer, token = \\\"hf...\\\", save_method = \\\"mxfp4\\\")\\n\",\n        \"\\n\",\n        \"# Merge and push to hub in 16bit\\n\",\n        \"if False:\\n\",\n        \"    model.save_pretrained_merged(\\\"finetuned_model\\\", tokenizer, save_method = \\\"merged_16bit\\\")\\n\",\n        \"if False: # Pushing to HF Hub\\n\",\n        \"    model.push_to_hub_merged(\\\"hf/gpt-oss-finetune\\\", tokenizer, save_method = \\\"merged_16bit\\\", token = \\\"\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"V15Yhj1V9lwG\"\n      },\n      \"source\": [\n        \"# And we're done!\\n\",\n        \"Congratulations, you just learned how to do reinforcement learning with gpt-oss! 
This notebook covered some advanced topics. To learn more about gpt-oss and RL, see Unsloth's [Reinforcement Learning Guide with gpt-oss](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning).\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"gpuType\": \"T4\",\n      \"provenance\": [],\n      \"include_colab_link\": true\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"widgets\": {\n      \"application/vnd.jupyter.widget-state+json\": {\n        \"02d120e49f2c4f95a6090b1d8d521767\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_dbf5ed93dac646ed979fa7a8c569dfe3\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_4db5ee5b7b674abba75fbce264e6dfa3\",\n            \"value\": \" 165/165 [00:00&lt;00:00, 17.9kB/s]\"\n          }\n        },\n        \"04d39c4dda9f4a1bb01b8d6320032372\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": 
\"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"06ab9eaa6f0f48c4b68cff1ca4b9f2fa\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            
\"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"07f0420c4dfa477caccd7ae96551c2e4\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_ad75f887a140416abfca615b2fc3c385\",\n            \"max\": 3996690997,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_dee02a37a6f44f168546ee0077dc20d1\",\n            \"value\": 3996690997\n          }\n        },\n        \"0ac4d8e674804ad6bdc5f2d62f2e0d33\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_7bfcd9acf29646db8b6123708d1ffe27\",\n   
           \"IPY_MODEL_5e88d6515f16475fb72d7c153422b591\",\n              \"IPY_MODEL_5e5b77dd649547f896ab306fccc94a4e\"\n            ],\n            \"layout\": \"IPY_MODEL_a843fa23e6c94fb486bff8764574fdc5\"\n          }\n        },\n        \"0c0c96eeac664f339aa4511bf47087e2\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_18451e19df5449b1853b5e13dacd19c5\",\n              \"IPY_MODEL_d864d29d02c54ecfaedd7b866a6df8c2\",\n              \"IPY_MODEL_7875163297284832a35aca84cbb105ce\"\n            ],\n            \"layout\": \"IPY_MODEL_d42d8228ea1247a1a81bb99b18c4640c\"\n          }\n        },\n        \"0f99489932aa409b94ba34764aff19b0\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"1183d3f2ad3c4fb0af1d925b5f9e3efe\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          
\"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_9cc51d8029eb4217bc37daa918649692\",\n              \"IPY_MODEL_41f13d2f023e405180689e03bc2c32a1\",\n              \"IPY_MODEL_247484c0bf5945bcb4627b48928366c8\"\n            ],\n            \"layout\": \"IPY_MODEL_14c0f20a9ab341ee966fe77815099ff0\"\n          }\n        },\n        \"147743757c804b85af2ef194f5f84e6a\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            
\"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"14c0f20a9ab341ee966fe77815099ff0\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n       
     \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"152d7bf2a74f400db3d3ecaa719ef8d1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"18451e19df5449b1853b5e13dacd19c5\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n         
   \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_bcda4c9a48e943a6a0ef812fcd64a6db\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_61e491b843c347b6b2a9948de7caf01d\",\n            \"value\": \"tokenizer_config.json: \"\n          }\n        },\n        \"1c96edb2f7c948b9968b1239982af942\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_ee23056662ad4b719b65005d776e0e72\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_87765ca0996b403dbe29deef48d548bf\",\n            \"value\": \" 4.00G/4.00G [01:42&lt;00:00, 117MB/s]\"\n          }\n        },\n        \"219ca32ab51e4b4385b2c1026a78503a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_6c2ccfe3363b40b58fc26ea164d4ead4\",\n          
    \"IPY_MODEL_07f0420c4dfa477caccd7ae96551c2e4\",\n              \"IPY_MODEL_1c96edb2f7c948b9968b1239982af942\"\n            ],\n            \"layout\": \"IPY_MODEL_d93be4994f104b6e99d89a9e73cd6abd\"\n          }\n        },\n        \"245590db7d374515a428ff4abbd25588\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"247484c0bf5945bcb4627b48928366c8\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_cef064f1c55f41bf957fc4623260fdb4\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_37cbe8800af04a42a0355922969b6393\",\n            \"value\": \" 4/4 [01:00&lt;00:00, 13.06s/it]\"\n          }\n        },\n        \"263b7dc0b3fd465fac89b9266b19d526\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          
\"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_147743757c804b85af2ef194f5f84e6a\",\n            \"max\": 4,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_2820e352ab004e818949acc31eb3888d\",\n            \"value\": 4\n          }\n        },\n        \"2820e352ab004e818949acc31eb3888d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"2a6aa92676c74509b58373ca604c5b3b\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n          
  \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"2a6f43b64d164636a2d9708f0190f21b\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": 
\"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"2c40c6b846924200b29616a590af1672\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            
\"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_06ab9eaa6f0f48c4b68cff1ca4b9f2fa\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_d98c2b1e979b4929891a8ee0c11f55df\",\n            \"value\": \"model.safetensors.index.json: \"\n          }\n        },\n        \"2fa84865e9f14c1491402ef81517b4bd\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n         
   \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"32d6af64f2464cfb965671f2692b4e15\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            
\"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"34a9e38b0b454a69a067d1ddadec7626\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_9c4d6839934b4b13952a850d2084d498\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_c6a1decbc0e7421db622033214913cb9\",\n            \"value\": \"Fetching 4 files: 100%\"\n          }\n        },\n        \"350f29f737534bfba4258bc31ec274a2\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            
\"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"36676899a61f4be4b631f6271f6ecec9\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": 
null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"37cbe8800af04a42a0355922969b6393\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"3f9b801b52da4eb79f730d87bea5c338\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            
\"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_b66c6ded549d4db8a2e5ea8e5016615c\",\n              \"IPY_MODEL_43da5073c3ad4e98a3ade17a0bb3b93d\",\n              \"IPY_MODEL_40365e2c9fef49148e4c93592d458afc\"\n            ],\n            \"layout\": \"IPY_MODEL_7e9d5212fc7844f286e14b70cbf0bc7a\"\n          }\n        },\n        \"40138ff29073407abb95f793509fc320\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"40365e2c9fef49148e4c93592d458afc\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": 
\"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_2a6f43b64d164636a2d9708f0190f21b\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_65c62d2198e64ee4a9e6547c2733135a\",\n            \"value\": \" 1.16G/1.16G [00:25&lt;00:00, 39.8MB/s]\"\n          }\n        },\n        \"41f13d2f023e405180689e03bc2c32a1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_36676899a61f4be4b631f6271f6ecec9\",\n            \"max\": 4,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_77ecad9f150c430fa85f5833d97c42df\",\n            \"value\": 4\n          }\n        },\n        \"43da5073c3ad4e98a3ade17a0bb3b93d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n 
           \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_4513a73fa95b41b5b6edadc9143ba9c1\",\n            \"max\": 1158267008,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_792d75a7d18945e7972826ac5b2ac386\",\n            \"value\": 1158267008\n          }\n        },\n        \"4513a73fa95b41b5b6edadc9143ba9c1\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n      
      \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"48741bbdeccb459aa4eea9c61339764b\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"4b9b3fe8dc764eedb9e18f166fe2f548\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_87a808c4d4f54f719adcd29de7206e1b\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_5f0b2a0e1953406b88af2c884904e2da\",\n            \"value\": \"model-00003-of-00004.safetensors: 100%\"\n          }\n        },\n        
\"4cb119127b404f46a53012c62d004e28\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"4d67b10ec7794170addb4e968e20f170\": {\n          \"model_module\": 
\"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"4da21f53bf7f4e2d8132eb43e6ecc739\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n        
  \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"4db5ee5b7b674abba75fbce264e6dfa3\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n          
  \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"4fbc4cfe529d471ba85f3ae8e53b28d6\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_a0d0fedc5bec4f5b943fddf9a954fbdf\",\n              \"IPY_MODEL_cab602573c6940919f93e59fe6f4838d\",\n              \"IPY_MODEL_51b8f4ce40f94ac39cf44d98f1522ec7\"\n            ],\n            \"layout\": \"IPY_MODEL_32d6af64f2464cfb965671f2692b4e15\"\n          }\n        },\n        \"51aaa109480d4ae6bd419aea689d22ee\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n        
    \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"51b8f4ce40f94ac39cf44d98f1522ec7\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n         
   \"layout\": \"IPY_MODEL_60ceb890b5644493a8886d91b9dac461\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_40138ff29073407abb95f793509fc320\",\n            \"value\": \" 446/446 [00:00&lt;00:00, 50.5kB/s]\"\n          }\n        },\n        \"55ac5c2a82ee48fe988e1e4f26c168b0\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"5657a84bf4b74710b2de1a54f9236e39\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": 
null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"596c2a62a635469eb74233ce00586a6f\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": 
null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"59e46bbe96df4b88ad31c09096ce0e0a\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            
\"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"5a59fb5f7acf4213847c985e66c9ee3c\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_81a728910a2341a785a6f252bbb371f7\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_69a8d50f11244ba688c183d14d2395ec\",\n            \"value\": \"generation_config.json: 100%\"\n          }\n        },\n        \"5b7af68130f04a63ad3efa3d9f602ebe\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n  
          \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_80fa3aef5e2040d9904c6b87b7214ca0\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_0f99489932aa409b94ba34764aff19b0\",\n            \"value\": \" 4/4 [01:42&lt;00:00, 42.23s/it]\"\n          }\n        },\n        \"5e5b77dd649547f896ab306fccc94a4e\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_59e46bbe96df4b88ad31c09096ce0e0a\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_8f5c7b88a2cc4b5abb0814c814833349\",\n            \"value\": \" 15.1k/? 
[00:00&lt;00:00, 1.37MB/s]\"\n          }\n        },\n        \"5e88d6515f16475fb72d7c153422b591\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_923653dfe90e475a9efa44baf98ba9a0\",\n            \"max\": 1,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_62600092f8cc43f493b86b0169f67be1\",\n            \"value\": 1\n          }\n        },\n        \"5ebe7b4e4ed24c53b783ee46377c682d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_51aaa109480d4ae6bd419aea689d22ee\",\n            \"max\": 3998751275,\n            \"min\": 0,\n            
\"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_acf4e50a248342f68d26daef21baa419\",\n            \"value\": 3998751275\n          }\n        },\n        \"5f0b2a0e1953406b88af2c884904e2da\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"60ceb890b5644493a8886d91b9dac461\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": 
null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"614c5332c7d045109102a329e7f69dfd\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"61e491b843c347b6b2a9948de7caf01d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n  
          \"description_width\": \"\"\n          }\n        },\n        \"62600092f8cc43f493b86b0169f67be1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"65c62d2198e64ee4a9e6547c2733135a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"68ea891644ca4753a8e1bf278ff47e84\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            
\"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"69a8d50f11244ba688c183d14d2395ec\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n   
     \"6a47e60b10a6481b94aee021c8dbc7ba\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"6ab4e5676ad84807a126fffa99f7a0d4\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_e61ef80398444c13bf7cd20ef21a5057\",\n              \"IPY_MODEL_5ebe7b4e4ed24c53b783ee46377c682d\",\n              \"IPY_MODEL_e0fdef0087bc4a91a11932a2d933c001\"\n            ],\n            \"layout\": \"IPY_MODEL_596c2a62a635469eb74233ce00586a6f\"\n          }\n        },\n        \"6c2ccfe3363b40b58fc26ea164d4ead4\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": 
\"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_4da21f53bf7f4e2d8132eb43e6ecc739\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_735f70fac43449e3974de1b783d56d33\",\n            \"value\": \"model-00002-of-00004.safetensors: 100%\"\n          }\n        },\n        \"735f70fac43449e3974de1b783d56d33\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"749e8407a901483c8b513a2fb71596c8\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            
\"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_ef01b874478b4bb497d31d2f8dd6145a\",\n            \"max\": 1,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_d50ea8cded9848ffa18be1ae6a2559df\",\n            \"value\": 1\n          }\n        },\n        \"751a46fbb8e24efabfb381a85c90fbe8\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n         
   \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"77204d81ff8f4ee585361a503fa647dc\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"77d34c0f1de548b4872208a063bb5017\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n           
 \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"77ecad9f150c430fa85f5833d97c42df\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"7841bc90b6a74120ab3e603c76332a01\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": 
\"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"7875163297284832a35aca84cbb105ce\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_ba94310dc12a4a258205b14901ad3f94\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_a93210a691414502ba3c2dff03ffb4ce\",\n            \"value\": \" 22.8k/? 
[00:00&lt;00:00, 1.66MB/s]\"\n          }\n        },\n        \"792d75a7d18945e7972826ac5b2ac386\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"7baca79d720c40b5a923b9717e28c982\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_ffabf89ecd9d48a5a3fc2a1c855ce080\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_614c5332c7d045109102a329e7f69dfd\",\n            \"value\": \" 1.19M/? 
[00:00&lt;00:00, 81.8MB/s]\"\n          }\n        },\n        \"7bd5d1beeb0e49e293d9f6b91bb6d7fb\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"7bfcd9acf29646db8b6123708d1ffe27\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_fd0ac7ed3d3146ec85913f4e05c4a2f6\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_77204d81ff8f4ee585361a503fa647dc\",\n            \"value\": \"chat_template.jinja: \"\n          }\n        },\n        \"7d3379cbd27a4218a9d84c5a12f3bb88\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            
\"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"7e9d5212fc7844f286e14b70cbf0bc7a\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": 
\"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"80fa3aef5e2040d9904c6b87b7214ca0\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": 
\"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"81a728910a2341a785a6f252bbb371f7\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n      
      \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"84d27c45065e426badbfcfcdc8ff16b6\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": 
\"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_4d67b10ec7794170addb4e968e20f170\",\n            \"max\": 27868174,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_55ac5c2a82ee48fe988e1e4f26c168b0\",\n            \"value\": 27868174\n          }\n        },\n        \"87765ca0996b403dbe29deef48d548bf\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"87a808c4d4f54f719adcd29de7206e1b\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            
\"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"8c7c6bb04a3f4a1494b34529f95a195c\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"8db5e86577744ff1a39c8e198eee5dd3\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": 
\"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_4b9b3fe8dc764eedb9e18f166fe2f548\",\n              \"IPY_MODEL_cca95e973bc445d3811335debf7c446e\",\n              \"IPY_MODEL_e507a46b4c754d9a8aede2aac0d203bc\"\n            ],\n            \"layout\": \"IPY_MODEL_751a46fbb8e24efabfb381a85c90fbe8\"\n          }\n        },\n        \"8f1e6c36b84c4115a671dcb9ade41c8b\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            
\"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"8f5c7b88a2cc4b5abb0814c814833349\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"923653dfe90e475a9efa44baf98ba9a0\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            
\"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": \"20px\"\n          }\n        },\n        \"9a079a30b4ae4bbc80122faf83e0ad59\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": 
null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"9beac0680e3049dfafcb6ec185fd2265\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"9c4d6839934b4b13952a850d2084d498\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": 
\"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"9cc51d8029eb4217bc37daa918649692\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n     
       \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_a219f3b89a34443abe612846676f9356\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_152d7bf2a74f400db3d3ecaa719ef8d1\",\n            \"value\": \"Loading checkpoint shards: 100%\"\n          }\n        },\n        \"a0d0fedc5bec4f5b943fddf9a954fbdf\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_e1e77d98b01f4376a6c075975c27571e\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_6a47e60b10a6481b94aee021c8dbc7ba\",\n            \"value\": \"special_tokens_map.json: 100%\"\n          }\n        },\n        \"a219f3b89a34443abe612846676f9356\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            
\"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"a843fa23e6c94fb486bff8764574fdc5\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            
\"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"a93210a691414502ba3c2dff03ffb4ce\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n          
  \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"abe2b0a2913d4633943f44333ae799f8\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_2c40c6b846924200b29616a590af1672\",\n              \"IPY_MODEL_749e8407a901483c8b513a2fb71596c8\",\n              \"IPY_MODEL_7baca79d720c40b5a923b9717e28c982\"\n            ],\n            \"layout\": \"IPY_MODEL_68ea891644ca4753a8e1bf278ff47e84\"\n          }\n        },\n        \"acda8e7582934fecbbf854e66e23f698\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"acf4e50a248342f68d26daef21baa419\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          
\"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"ad75f887a140416abfca615b2fc3c385\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n 
           \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"ae6d42fb84fc4984af1d4430acdcd3c9\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_350f29f737534bfba4258bc31ec274a2\",\n            \"max\": 165,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_9beac0680e3049dfafcb6ec185fd2265\",\n            \"value\": 165\n          }\n        },\n        \"b07acf871a0a46f1889bfb439d13752b\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            
\"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"b66c6ded549d4db8a2e5ea8e5016615c\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_77d34c0f1de548b4872208a063bb5017\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_bf96e8666c224c26b0a01451d08e907a\",\n            \"value\": \"model-00004-of-00004.safetensors: 100%\"\n          }\n        },\n        \"ba94310dc12a4a258205b14901ad3f94\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": 
null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"bcda4c9a48e943a6a0ef812fcd64a6db\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": 
null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"bf96e8666c224c26b0a01451d08e907a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"c6a1decbc0e7421db622033214913cb9\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": 
\"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"cab602573c6940919f93e59fe6f4838d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_5657a84bf4b74710b2de1a54f9236e39\",\n            \"max\": 446,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_7bd5d1beeb0e49e293d9f6b91bb6d7fb\",\n            \"value\": 446\n          }\n        },\n        \"caf742160db041a1b6c2cfdf78f2dc9a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": 
[\n              \"IPY_MODEL_34a9e38b0b454a69a067d1ddadec7626\",\n              \"IPY_MODEL_263b7dc0b3fd465fac89b9266b19d526\",\n              \"IPY_MODEL_5b7af68130f04a63ad3efa3d9f602ebe\"\n            ],\n            \"layout\": \"IPY_MODEL_2a6aa92676c74509b58373ca604c5b3b\"\n          }\n        },\n        \"cca95e973bc445d3811335debf7c446e\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_2fa84865e9f14c1491402ef81517b4bd\",\n            \"max\": 3372033380,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_245590db7d374515a428ff4abbd25588\",\n            \"value\": 3372033380\n          }\n        },\n        \"cef064f1c55f41bf957fc4623260fdb4\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": 
null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d42d8228ea1247a1a81bb99b18c4640c\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n           
 \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d50ea8cded9848ffa18be1ae6a2559df\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"d864d29d02c54ecfaedd7b866a6df8c2\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n     
     \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_dee07d33b8de4c3b847fcff670e68102\",\n            \"max\": 1,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_b07acf871a0a46f1889bfb439d13752b\",\n            \"value\": 1\n          }\n        },\n        \"d9020a2a2c8440db81d2cfdf0289b667\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            
\"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d93be4994f104b6e99d89a9e73cd6abd\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            
\"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d98c2b1e979b4929891a8ee0c11f55df\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"da4324e287e64e5ba98fc110693066df\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n  
          \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"dbf5ed93dac646ed979fa7a8c569dfe3\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            
\"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"dbfeea8ee2374b8c8fa70431c35f281f\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": 
\"IPY_MODEL_d9020a2a2c8440db81d2cfdf0289b667\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_04d39c4dda9f4a1bb01b8d6320032372\",\n            \"value\": \"tokenizer.json: 100%\"\n          }\n        },\n        \"dee02a37a6f44f168546ee0077dc20d1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"dee07d33b8de4c3b847fcff670e68102\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            
\"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": \"20px\"\n          }\n        },\n        \"e0fdef0087bc4a91a11932a2d933c001\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_7d3379cbd27a4218a9d84c5a12f3bb88\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_7841bc90b6a74120ab3e603c76332a01\",\n            \"value\": \" 4.00G/4.00G [01:41&lt;00:00, 60.6MB/s]\"\n          }\n        },\n        \"e1e77d98b01f4376a6c075975c27571e\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": 
\"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"e2973e6c02834a7c9f2f6ce5755f35f0\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": 
\"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"e507a46b4c754d9a8aede2aac0d203bc\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            
\"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_e2973e6c02834a7c9f2f6ce5755f35f0\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_48741bbdeccb459aa4eea9c61339764b\",\n            \"value\": \" 3.37G/3.37G [01:40&lt;00:00, 32.0MB/s]\"\n          }\n        },\n        \"e61ef80398444c13bf7cd20ef21a5057\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_da4324e287e64e5ba98fc110693066df\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_8c7c6bb04a3f4a1494b34529f95a195c\",\n            \"value\": \"model-00001-of-00004.safetensors: 100%\"\n          }\n        },\n        \"ee23056662ad4b719b65005d776e0e72\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": 
null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"ef01b874478b4bb497d31d2f8dd6145a\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            
\"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": \"20px\"\n          }\n        },\n        \"f8dacdab001d4db0b6b3776ac7d3634a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": 
\"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_5a59fb5f7acf4213847c985e66c9ee3c\",\n              \"IPY_MODEL_ae6d42fb84fc4984af1d4430acdcd3c9\",\n              \"IPY_MODEL_02d120e49f2c4f95a6090b1d8d521767\"\n            ],\n            \"layout\": \"IPY_MODEL_8f1e6c36b84c4115a671dcb9ade41c8b\"\n          }\n        },\n        \"fa9ea0d3234e41689c827485d0360885\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_9a079a30b4ae4bbc80122faf83e0ad59\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_acda8e7582934fecbbf854e66e23f698\",\n            \"value\": \" 27.9M/27.9M [00:00&lt;00:00, 44.5MB/s]\"\n          }\n        },\n        \"fd0ac7ed3d3146ec85913f4e05c4a2f6\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            
\"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"fd2fe9ef6da64f72ab29d481d1739f5e\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              
\"IPY_MODEL_dbfeea8ee2374b8c8fa70431c35f281f\",\n              \"IPY_MODEL_84d27c45065e426badbfcfcdc8ff16b6\",\n              \"IPY_MODEL_fa9ea0d3234e41689c827485d0360885\"\n            ],\n            \"layout\": \"IPY_MODEL_4cb119127b404f46a53012c62d004e28\"\n          }\n        },\n        \"ffabf89ecd9d48a5a3fc2a1c855ce080\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n           
 \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"state\" : {}\n      }\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/streamlit/streamlit_chat.py",
    "content": "import json\n\nimport requests\nimport streamlit as st\n\nDEFAULT_FUNCTION_PROPERTIES = \"\"\"\n{\n    \"type\": \"object\",\n    \"properties\": {\n        \"location\": {\n            \"type\": \"string\",\n            \"description\": \"The city and state, e.g. San Francisco, CA\"\n        }\n    },\n    \"required\": [\"location\"]\n}\n\"\"\".strip()\n\n# Session state for chat\nif \"messages\" not in st.session_state:\n    st.session_state.messages = []\n\nst.title(\"💬 Chatbot\")\n\nif \"model\" not in st.session_state:\n    if \"model\" in st.query_params:\n        st.session_state.model = st.query_params[\"model\"]\n    else:\n        st.session_state.model = \"small\"\n\noptions = [\"large\", \"small\"]\nselection = st.sidebar.segmented_control(\n    \"Model\", options, selection_mode=\"single\", default=st.session_state.model\n)\n# st.session_state.model = selection\nst.query_params.update({\"model\": selection})\n\ninstructions = st.sidebar.text_area(\n    \"Instructions\",\n    value=\"You are a helpful assistant that can answer questions and help with tasks.\",\n)\neffort = st.sidebar.radio(\n    \"Reasoning effort\",\n    [\"low\", \"medium\", \"high\"],\n    index=1,\n)\nst.sidebar.divider()\nst.sidebar.subheader(\"Functions\")\nuse_functions = st.sidebar.toggle(\"Use functions\", value=False)\n\nst.sidebar.subheader(\"Built-in Tools\")\n# Built-in Tools section\nuse_browser_search = st.sidebar.toggle(\"Use browser search\", value=False)\nuse_code_interpreter = st.sidebar.toggle(\"Use code interpreter\", value=False)\n\nif use_functions:\n    function_name = st.sidebar.text_input(\"Function name\", value=\"get_weather\")\n    function_description = st.sidebar.text_area(\n        \"Function description\", value=\"Get the weather for a given city\"\n    )\n    function_parameters = st.sidebar.text_area(\n        \"Function parameters\", value=DEFAULT_FUNCTION_PROPERTIES\n    )\nelse:\n    function_name = None\n    function_description = 
None\n    function_parameters = None\nst.sidebar.divider()\ntemperature = st.sidebar.slider(\n    \"Temperature\", min_value=0.0, max_value=1.0, value=1.0, step=0.01\n)\nmax_output_tokens = st.sidebar.slider(\n    \"Max output tokens\", min_value=1, max_value=131072, value=30000, step=1000\n)\nst.sidebar.divider()\ndebug_mode = st.sidebar.toggle(\"Debug mode\", value=False)\n\nif debug_mode:\n    st.sidebar.divider()\n    st.sidebar.code(json.dumps(st.session_state.messages, indent=2), \"json\")\n\nrender_input = True\n\nURL = (\n    \"http://localhost:8081/v1/responses\"\n    if selection == options[1]\n    else \"http://localhost:8000/v1/responses\"\n)\n\n\ndef trigger_fake_tool(container):\n    function_output = st.session_state.get(\"function_output\", \"It's sunny!\")\n    last_call = st.session_state.messages[-1]\n    if last_call.get(\"type\") == \"function_call\":\n        st.session_state.messages.append(\n            {\n                \"type\": \"function_call_output\",\n                \"call_id\": last_call.get(\"call_id\"),\n                \"output\": function_output,\n            }\n        )\n        run(container)\n\n\ndef run(container):\n    tools = []\n    if use_functions:\n        tools.append(\n            {\n                \"type\": \"function\",\n                \"name\": function_name,\n                \"description\": function_description,\n                \"parameters\": json.loads(function_parameters),\n            }\n        )\n    # Add browser_search tool if checkbox is checked\n    if use_browser_search:\n        tools.append({\"type\": \"browser_search\"})\n    if use_code_interpreter:\n        tools.append({\"type\": \"code_interpreter\"})\n    response = requests.post(\n        URL,\n        json={\n            \"input\": st.session_state.messages,\n            \"stream\": True,\n            \"instructions\": instructions,\n            \"reasoning\": {\"effort\": effort},\n            \"metadata\": {\"__debug\": debug_mode},\n  
          \"tools\": tools,\n            \"temperature\": temperature,\n            \"max_output_tokens\": max_output_tokens,\n        },\n        stream=True,\n    )\n\n    text_delta = \"\"\n    code_interpreter_sessions: dict[str, dict] = {}\n\n    _current_output_index = 0\n    for line in response.iter_lines(decode_unicode=True):\n        if not line or not line.startswith(\"data:\"):\n            continue\n        data_str = line[len(\"data:\") :].strip()\n        if not data_str:\n            continue\n        try:\n            data = json.loads(data_str)\n        except Exception:\n            continue\n\n        event_type = data.get(\"type\", \"\")\n        output_index = data.get(\"output_index\", 0)\n        if event_type == \"response.output_item.added\":\n            _current_output_index = output_index\n            output_type = data.get(\"item\", {}).get(\"type\", \"message\")\n            if output_type == \"message\":\n                output = container.chat_message(\"assistant\")\n                placeholder = output.empty()\n            elif output_type == \"reasoning\":\n                output = container.chat_message(\"reasoning\", avatar=\"🤔\")\n                placeholder = output.empty()\n            elif output_type == \"web_search_call\":\n                output = container.chat_message(\"web_search_call\", avatar=\"🌐\")\n                output.code(\n                    json.dumps(data.get(\"item\", {}).get(\"action\", {}), indent=4),\n                    language=\"json\",\n                )\n                placeholder = output.empty()\n            elif output_type == \"code_interpreter_call\":\n                item = data.get(\"item\", {})\n                item_id = item.get(\"id\")\n                message_container = container.chat_message(\n                    \"code_interpreter_call\", avatar=\"🧪\"\n                )\n                status_placeholder = message_container.empty()\n                code_placeholder = 
message_container.empty()\n                outputs_container = message_container.container()\n                code_text = item.get(\"code\") or \"\"\n                if code_text:\n                    code_placeholder.code(code_text, language=\"python\")\n                code_interpreter_sessions[item_id] = {\n                    \"status\": status_placeholder,\n                    \"code\": code_placeholder,\n                    \"outputs\": outputs_container,\n                    \"code_text\": code_text,\n                    \"rendered_outputs\": False,\n                }\n                placeholder = status_placeholder\n            text_delta = \"\"\n        elif event_type == \"response.reasoning_text.delta\":\n            output.avatar = \"🤔\"\n            text_delta += data.get(\"delta\", \"\")\n            placeholder.markdown(text_delta)\n        elif event_type == \"response.output_text.delta\":\n            text_delta += data.get(\"delta\", \"\")\n            placeholder.markdown(text_delta)\n        elif event_type == \"response.output_item.done\":\n            item = data.get(\"item\", {})\n            if item.get(\"type\") == \"function_call\":\n                with container.chat_message(\"function_call\", avatar=\"🔨\"):\n                    st.markdown(f\"Called `{item.get('name')}`\")\n                    st.caption(\"Arguments\")\n                    st.code(item.get(\"arguments\", \"\"), language=\"json\")\n            if item.get(\"type\") == \"web_search_call\":\n                placeholder.markdown(\"✅ Done\")\n            if item.get(\"type\") == \"code_interpreter_call\":\n                item_id = item.get(\"id\")\n                session = code_interpreter_sessions.get(item_id)\n                if session:\n                    session[\"status\"].markdown(\"✅ Done\")\n                    final_code = item.get(\"code\") or session[\"code_text\"]\n                    if final_code:\n                        session[\"code\"].code(final_code, 
language=\"python\")\n                        session[\"code_text\"] = final_code\n                    outputs = item.get(\"outputs\") or []\n                    if outputs and not session[\"rendered_outputs\"]:\n                        with session[\"outputs\"]:\n                            st.markdown(\"**Outputs**\")\n                            for output_item in outputs:\n                                output_type = output_item.get(\"type\")\n                                if output_type == \"logs\":\n                                    st.code(\n                                        output_item.get(\"logs\", \"\"),\n                                        language=\"text\",\n                                    )\n                                elif output_type == \"image\":\n                                    st.image(\n                                        output_item.get(\"url\", \"\"),\n                                        caption=\"Code interpreter image\",\n                                    )\n                        session[\"rendered_outputs\"] = True\n                    elif not outputs and not session[\"rendered_outputs\"]:\n                        with session[\"outputs\"]:\n                            st.caption(\"(No outputs)\")\n                        session[\"rendered_outputs\"] = True\n                else:\n                    placeholder.markdown(\"✅ Done\")\n        elif event_type == \"response.code_interpreter_call.in_progress\":\n            item_id = data.get(\"item_id\")\n            session = code_interpreter_sessions.get(item_id)\n            if session:\n                session[\"status\"].markdown(\"⏳ Running\")\n            else:\n                try:\n                    placeholder.markdown(\"⏳ Running\")\n                except Exception:\n                    pass\n        elif event_type == \"response.code_interpreter_call.interpreting\":\n            item_id = data.get(\"item_id\")\n            session = 
code_interpreter_sessions.get(item_id)\n            if session:\n                session[\"status\"].markdown(\"🧮 Interpreting\")\n        elif event_type == \"response.code_interpreter_call.completed\":\n            item_id = data.get(\"item_id\")\n            session = code_interpreter_sessions.get(item_id)\n            if session:\n                session[\"status\"].markdown(\"✅ Done\")\n            else:\n                try:\n                    placeholder.markdown(\"✅ Done\")\n                except Exception:\n                    pass\n        elif event_type == \"response.code_interpreter_call_code.delta\":\n            item_id = data.get(\"item_id\")\n            session = code_interpreter_sessions.get(item_id)\n            if session:\n                session[\"code_text\"] += data.get(\"delta\", \"\")\n                if session[\"code_text\"].strip():\n                    session[\"code\"].code(session[\"code_text\"], language=\"python\")\n        elif event_type == \"response.code_interpreter_call_code.done\":\n            item_id = data.get(\"item_id\")\n            session = code_interpreter_sessions.get(item_id)\n            if session:\n                final_code = data.get(\"code\") or session[\"code_text\"]\n                session[\"code_text\"] = final_code\n                if final_code:\n                    session[\"code\"].code(final_code, language=\"python\")\n        elif event_type == \"response.completed\":\n            response = data.get(\"response\", {})\n            if debug_mode:\n                container.expander(\"Debug\", expanded=False).code(\n                    response.get(\"metadata\", {}).get(\"__debug\", \"\"), language=\"text\"\n                )\n            st.session_state.messages.extend(response.get(\"output\", []))\n            if st.session_state.messages[-1].get(\"type\") == \"function_call\":\n                with container.form(\"function_output_form\"):\n                    _function_output = 
st.text_input(\n                        \"Enter function output\",\n                        value=st.session_state.get(\"function_output\", \"It's sunny!\"),\n                        key=\"function_output\",\n                    )\n                    st.form_submit_button(\n                        \"Submit function output\",\n                        on_click=trigger_fake_tool,\n                        args=[container],\n                    )\n            # Optionally handle other event types...\n\n\n# Chat display\nfor msg in st.session_state.messages:\n    if msg.get(\"type\") == \"message\":\n        with st.chat_message(msg[\"role\"]):\n            for item in msg[\"content\"]:\n                if (\n                    item.get(\"type\") == \"text\"\n                    or item.get(\"type\") == \"output_text\"\n                    or item.get(\"type\") == \"input_text\"\n                ):\n                    st.markdown(item[\"text\"])\n                    if item.get(\"annotations\"):\n                        annotation_lines = \"\\n\".join(\n                            f\"- {annotation.get('url')}\"\n                            for annotation in item[\"annotations\"]\n                            if annotation.get(\"url\")\n                        )\n                        st.caption(f\"**Annotations:**\\n{annotation_lines}\")\n    elif msg.get(\"type\") == \"reasoning\":\n        with st.chat_message(\"reasoning\", avatar=\"🤔\"):\n            for item in msg[\"content\"]:\n                if item.get(\"type\") == \"reasoning_text\":\n                    st.markdown(item[\"text\"])\n    elif msg.get(\"type\") == \"function_call\":\n        with st.chat_message(\"function_call\", avatar=\"🔨\"):\n            st.markdown(f\"Called `{msg.get('name')}`\")\n            st.caption(\"Arguments\")\n            st.code(msg.get(\"arguments\", \"\"), language=\"json\")\n    elif msg.get(\"type\") == \"function_call_output\":\n        with 
st.chat_message(\"function_call_output\", avatar=\"✅\"):\n            st.caption(\"Output\")\n            st.code(msg.get(\"output\", \"\"), language=\"text\")\n    elif msg.get(\"type\") == \"web_search_call\":\n        with st.chat_message(\"web_search_call\", avatar=\"🌐\"):\n            st.code(json.dumps(msg.get(\"action\", {}), indent=4), language=\"json\")\n            st.markdown(\"✅ Done\")\n    elif msg.get(\"type\") == \"code_interpreter_call\":\n        with st.chat_message(\"code_interpreter_call\", avatar=\"🧪\"):\n            st.markdown(\"✅ Done\")\n\nif render_input:\n    # Input field\n    if prompt := st.chat_input(\"Type a message...\"):\n        st.session_state.messages.append(\n            {\n                \"type\": \"message\",\n                \"role\": \"user\",\n                \"content\": [{\"type\": \"input_text\", \"text\": prompt}],\n            }\n        )\n\n        with st.chat_message(\"user\"):\n            st.markdown(prompt)\n\n        run(st.container())\n"
  },
  {
    "path": "gpt-oss-mcp-server/README.md",
    "content": "# MCP Servers for gpt-oss reference tools\n\nThis directory contains MCP servers for the reference tools in the [gpt-oss](https://github.com/openai/gpt-oss) repository.\nYou can set up these tools behind MCP servers and use them in your applications.\nIf you are building an inference service that integrates with MCP, you can also use these as reference tools.\n\nIn particular, this directory contains a `build-system-prompt.py` script that generates exactly the same system prompt as `reference-system-prompt.py`.\nThe script showcases all the care needed to automatically discover the tools and construct the system prompt before feeding it into Harmony.\n\n## Usage\n\n```bash\n# Install the dependencies\nuv pip install -r requirements.txt\n```\n\n```bash\n# Assumes harmony and gpt-oss are already installed\nuv pip install \"mcp[cli]\"\n# Start the servers (run each command in its own terminal)\nmcp run -t sse browser_server.py:mcp\nmcp run -t sse python_server.py:mcp\n```\n\nYou can now use the MCP Inspector to experiment with the tools.\nOnce it is open, set the SSE URL to `http://localhost:8001/sse` for the browser server or `http://localhost:8000/sse` for the Python server.\n\nTo see how the system prompt is constructed via MCP service discovery, run `build-system-prompt.py` while both servers are running and compare its output with that of `reference-system-prompt.py`.\n"
  },
  {
    "path": "gpt-oss-mcp-server/browser_server.py",
    "content": "import os\nfrom collections.abc import AsyncIterator\nfrom contextlib import asynccontextmanager\nfrom dataclasses import dataclass, field\nfrom typing import Union, Optional\n\nfrom mcp.server.fastmcp import Context, FastMCP\nfrom gpt_oss.tools.simple_browser import SimpleBrowserTool\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend\n\n@dataclass\nclass AppContext:\n    browsers: dict[str, SimpleBrowserTool] = field(default_factory=dict)\n\n    def create_or_get_browser(self, session_id: str) -> SimpleBrowserTool:\n        if session_id not in self.browsers:\n            tool_backend = os.getenv(\"BROWSER_BACKEND\", \"exa\")\n            if tool_backend == \"youcom\":\n                backend = YouComBackend(source=\"web\")\n            elif tool_backend == \"exa\":\n                backend = ExaBackend(source=\"web\")\n            else:\n                raise ValueError(f\"Invalid tool backend: {tool_backend}\")\n            self.browsers[session_id] = SimpleBrowserTool(backend=backend)\n        return self.browsers[session_id]\n\n    def remove_browser(self, session_id: str) -> None:\n        self.browsers.pop(session_id, None)\n\n\n@asynccontextmanager\nasync def app_lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:\n    yield AppContext()\n\n\n# Pass lifespan to server\nmcp = FastMCP(\n    name=\"browser\",\n    instructions=r\"\"\"\nTool for browsing.\nThe `cursor` appears in brackets before each browsing display: `[{cursor}]`.\nCite information from the tool using the following format:\n`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`. 
\nDo not quote more than 10 words directly from the tool output.\nsources=web\n\"\"\".strip(),\n    lifespan=app_lifespan,\n    port=8001,\n)\n\n\n@mcp.tool(\n    name=\"search\",\n    title=\"Search for information\",\n    description=\n    \"Searches for information related to `query` and displays `topn` results.\",\n)\nasync def search(ctx: Context,\n                 query: str,\n                 topn: int = 10,\n                 source: Optional[str] = None) -> str:\n    \"\"\"Search for information related to a query\"\"\"\n    browser = ctx.request_context.lifespan_context.create_or_get_browser(\n        ctx.client_id)\n    messages = []\n    async for message in browser.search(query=query, topn=topn, source=source):\n        if message.content and hasattr(message.content[0], 'text'):\n            messages.append(message.content[0].text)\n    return \"\\n\".join(messages)\n\n\n@mcp.tool(\n    name=\"open\",\n    title=\"Open a link or page\",\n    description=\"\"\"\nOpens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\nValid link ids are displayed with the formatting: `【{id}†.*】`.\nIf `cursor` is not provided, the most recent page is implied.\nIf `id` is a string, it is treated as a fully qualified URL associated with `source`.\nIf `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\nUse this function without `id` to scroll to a new location of an opened page.\n\"\"\".strip(),\n)\nasync def open_link(ctx: Context,\n                    id: Union[int, str] = -1,\n                    cursor: int = -1,\n                    loc: int = -1,\n                    num_lines: int = -1,\n                    view_source: bool = False,\n                    source: Optional[str] = None) -> str:\n    \"\"\"Open a link or navigate to a page location\"\"\"\n    browser = 
ctx.request_context.lifespan_context.create_or_get_browser(\n        ctx.client_id)\n    messages = []\n    async for message in browser.open(id=id,\n                                      cursor=cursor,\n                                      loc=loc,\n                                      num_lines=num_lines,\n                                      view_source=view_source,\n                                      source=source):\n        if message.content and hasattr(message.content[0], 'text'):\n            messages.append(message.content[0].text)\n    return \"\\n\".join(messages)\n\n\n@mcp.tool(\n    name=\"find\",\n    title=\"Find pattern in page\",\n    description=\n    \"Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\",\n)\nasync def find_pattern(ctx: Context, pattern: str, cursor: int = -1) -> str:\n    \"\"\"Find exact matches of a pattern in the current page\"\"\"\n    browser = ctx.request_context.lifespan_context.create_or_get_browser(\n        ctx.client_id)\n    messages = []\n    async for message in browser.find(pattern=pattern, cursor=cursor):\n        if message.content and hasattr(message.content[0], 'text'):\n            messages.append(message.content[0].text)\n    return \"\\n\".join(messages)\n"
  },
  {
    "path": "gpt-oss-mcp-server/build-system-prompt.py",
    "content": "import datetime\nimport asyncio\n\nfrom gpt_oss.tokenizer import get_tokenizer\n\nfrom openai_harmony import (\n    Conversation,\n    DeveloperContent,\n    HarmonyEncodingName,\n    Message,\n    ReasoningEffort,\n    Role,\n    SystemContent,\n    ToolNamespaceConfig,\n    ToolDescription,\n    load_harmony_encoding,\n)\n\nfrom mcp import ClientSession\nfrom mcp.client.sse import sse_client\nfrom mcp.types import ListToolsResult\n\n\nasync def list_server_and_tools(server_url: str):\n    async with sse_client(url=server_url) as streams, ClientSession(\n            *streams) as session:\n        initialize_response = await session.initialize()\n        list_tools_response = await session.list_tools()\n        return initialize_response, list_tools_response\n\n\ndef trim_schema(schema: dict) -> dict:\n    # Convert the JSON Schema generated by MCP into Harmony's variant.\n    if \"title\" in schema:\n        del schema[\"title\"]\n    if \"default\" in schema and schema[\"default\"] is None:\n        del schema[\"default\"]\n    if \"anyOf\" in schema:\n        # Turn \"anyOf\": [{\"type\": \"type-1\"}, {\"type\": \"type-2\"}] into \"type\": [\"type-1\", \"type-2\"]\n        # and drop the \"null\" type, which Harmony ignores anyway\n        types = [\n            type_dict[\"type\"] for type_dict in schema[\"anyOf\"]\n            if type_dict[\"type\"] != 'null'\n        ]\n        schema[\"type\"] = types\n        del schema[\"anyOf\"]\n    if \"properties\" in schema:\n        schema[\"properties\"] = {\n            k: trim_schema(v)\n            for k, v in schema[\"properties\"].items()\n        }\n    return schema\n\n\ndef post_process_tools_description(\n        list_tools_result: ListToolsResult) -> ListToolsResult:\n    # Adapt the MCP tool result for Harmony\n    for tool in list_tools_result.tools:\n        tool.inputSchema = trim_schema(tool.inputSchema)\n\n    # Some tool schemas don't need to be part of the 
prompt (e.g. simple text in text out for Python)\n    list_tools_result.tools = [\n        tool for tool in list_tools_result.tools\n        if getattr(tool.annotations, \"include_in_prompt\", True)\n    ]\n\n    return list_tools_result\n\ntokenizer = get_tokenizer()\n\ntools_urls = [\n    \"http://localhost:8001/sse\",  # browser\n    \"http://localhost:8000/sse\",  # python\n]\nharmony_tool_descriptions = []\nfor tools_url in tools_urls:\n\n    initialize_response, list_tools_response = asyncio.run(\n        list_server_and_tools(tools_url))\n\n    list_tools_response = post_process_tools_description(list_tools_response)\n\n    tool_from_mcp = ToolNamespaceConfig(\n        name=initialize_response.serverInfo.name,\n        description=initialize_response.instructions,\n        tools=[\n            ToolDescription.new(name=tool.name,\n                                description=tool.description,\n                                parameters=tool.inputSchema)\n            for tool in list_tools_response.tools\n        ])\n    harmony_tool_descriptions.append(tool_from_mcp)\n\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\nsystem_message_content = (SystemContent.new().with_reasoning_effort(\n    ReasoningEffort.LOW).with_conversation_start_date(\n        datetime.datetime.now().strftime(\"%Y-%m-%d\")))\n\nfor tool_description in harmony_tool_descriptions:\n    system_message_content = system_message_content.with_tools(\n        tool_description)\n\nsystem_message = Message.from_role_and_content(Role.SYSTEM,\n                                               system_message_content)\n\ndeveloper_message_content = DeveloperContent.new().with_instructions(\"\")\ndeveloper_message = Message.from_role_and_content(Role.DEVELOPER,\n                                                  developer_message_content)\n\nmessages = [system_message, developer_message]\n\nconversation = Conversation.from_messages(messages)\ntokens = 
encoding.render_conversation(conversation)\nsystem_message = tokenizer.decode(tokens)\nprint(system_message)\n"
  },
  {
    "path": "gpt-oss-mcp-server/pyproject.toml",
    "content": "[project]\nname = \"gpt-oss-mcp-server\"\nversion = \"0.1.0\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"mcp[cli]>=1.12.2\",\n    # \"gpt_oss\"\n]\n"
  },
  {
    "path": "gpt-oss-mcp-server/python_server.py",
    "content": "from mcp.server.fastmcp import FastMCP\nfrom gpt_oss.tools.python_docker.docker_tool import PythonTool\nfrom openai_harmony import Message, TextContent, Author, Role\n\n# Create the server (unlike the browser server, no lifespan context is needed here)\nmcp = FastMCP(\n    name=\"python\",\n    instructions=r\"\"\"\nUse this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\nWhen you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.\n\"\"\".strip(),\n)\n\n\n@mcp.tool(\n    name=\"python\",\n    title=\"Execute Python code\",\n    description=\"\"\"\nUse this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\nWhen you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.\n    \"\"\",\n    annotations={\n        # Keep this schema out of the Harmony prompt: the tool is plain text in, text out\n        \"include_in_prompt\": False,\n    })\nasync def python(code: str) -> str:\n    tool = PythonTool()\n    messages = []\n    async for message in tool.process(\n            Message(author=Author(role=Role.TOOL, name=\"python\"),\n                    content=[TextContent(text=code)])):\n        messages.append(message)\n    return \"\\n\".join([message.content[0].text for message in messages])\n"
  },
  {
    "path": "gpt-oss-mcp-server/reference-system-prompt.py",
    "content": "import datetime\n\nfrom gpt_oss.tools.simple_browser import SimpleBrowserTool\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend\nfrom gpt_oss.tools.python_docker.docker_tool import PythonTool\nfrom gpt_oss.tokenizer import tokenizer\n\nfrom openai_harmony import (\n    Conversation,\n    DeveloperContent,\n    HarmonyEncodingName,\n    Message,\n    ReasoningEffort,\n    Role,\n    SystemContent,\n    load_harmony_encoding,\n)\n\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\nsystem_message_content = (SystemContent.new().with_reasoning_effort(\n    ReasoningEffort.LOW).with_conversation_start_date(\n        datetime.datetime.now().strftime(\"%Y-%m-%d\")))\n\nbackend = YouComBackend(source=\"web\")\nbrowser_tool = SimpleBrowserTool(backend=backend)\nsystem_message_content = system_message_content.with_tools(\n    browser_tool.tool_config)\n\npython_tool = PythonTool()\nsystem_message_content = system_message_content.with_tools(\n    python_tool.tool_config)\n\nsystem_message = Message.from_role_and_content(Role.SYSTEM,\n                                               system_message_content)\n\ndeveloper_message_content = DeveloperContent.new().with_instructions(\"\")\ndeveloper_message = Message.from_role_and_content(Role.DEVELOPER,\n                                                  developer_message_content)\n\nmessages = [system_message, developer_message]\n\nconversation = Conversation.from_messages(messages)\ntokens = encoding.render_conversation(conversation)\nsystem_message = tokenizer.decode(tokens)\nprint(system_message)\n"
  },
  {
    "path": "gpt_oss/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/chat.py",
    "content": "\"\"\"\nHarmony chat with tools\n\"\"\"\n\nimport atexit\nimport argparse\nimport asyncio\nimport datetime\nimport os\nfrom pathlib import Path\n\ntry:\n    import gnureadline as readline\nexcept ImportError:\n    import readline\n\nimport torch\nimport termcolor\n\nfrom gpt_oss.tools import apply_patch\nfrom gpt_oss.tools.simple_browser import SimpleBrowserTool\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend\nfrom gpt_oss.tools.python_docker.docker_tool import PythonTool\n\nfrom openai_harmony import (\n    Author,\n    Conversation,\n    DeveloperContent,\n    HarmonyEncodingName,\n    Message,\n    ReasoningEffort,\n    Role,\n    StreamableParser,\n    StreamState,\n    SystemContent,\n    TextContent,\n    ToolDescription,\n    load_harmony_encoding,\n)\n\n\nREASONING_EFFORT = {\n    \"high\": ReasoningEffort.HIGH,\n    \"medium\": ReasoningEffort.MEDIUM,\n    \"low\": ReasoningEffort.LOW,\n}\n\n\ndef get_user_input():\n    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0\n    if rank == 0:\n        user_input = input()\n    else:\n        user_input = \"\"\n    user_input_list = [user_input]\n    if torch.distributed.is_initialized():\n        torch.distributed.broadcast_object_list(user_input_list, 0)\n    return user_input_list[0]\n\n\ndef main(args):\n    match args.backend:\n        case \"triton\":\n            from gpt_oss.triton.model import TokenGenerator as TritonGenerator\n            from gpt_oss.torch.utils import init_distributed\n            device = init_distributed()\n            generator = TritonGenerator(args.checkpoint, args.context, device)\n        case \"torch\":\n            from gpt_oss.torch.model import TokenGenerator as TorchGenerator\n            from gpt_oss.torch.utils import init_distributed\n            device = init_distributed()\n            generator = TorchGenerator(args.checkpoint, device)\n        case \"vllm\":\n            from gpt_oss.vllm.token_generator 
import TokenGenerator as VLLMGenerator\n            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)\n        case _:\n            raise ValueError(f\"Invalid backend: {args.backend}\")\n\n    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\n    system_message_content = (\n        SystemContent.new()\n        .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])\n        .with_conversation_start_date(datetime.datetime.now().strftime(\"%Y-%m-%d\"))\n    )\n\n    if args.browser:\n        backend = YouComBackend(\n            source=\"web\",\n        )\n        browser_tool = SimpleBrowserTool(backend=backend)\n        system_message_content = system_message_content.with_tools(browser_tool.tool_config)\n\n    if args.python:\n        python_tool = PythonTool()\n        system_message_content = system_message_content.with_tools(python_tool.tool_config)\n\n    system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)\n    messages = [system_message]\n\n    if args.apply_patch:\n        apply_patch_instructions = Path(apply_patch.__file__).parent / \"apply_patch.md\"\n        developer_message = \"\"\n        if args.developer_message:\n            developer_message = args.developer_message + \"\\n\"\n        developer_message += apply_patch_instructions.read_text()\n        developer_message_content = (\n            DeveloperContent.new()\n            .with_instructions(developer_message)\n            .with_function_tools([\n                ToolDescription.new(\n                    \"apply_patch\",\n                    \"Patch a file\",\n                    parameters={\n                        \"type\": \"string\",\n                        \"description\": \"Formatted patch code\",\n                        \"default\": \"*** Begin Patch\\n*** End Patch\\n\",\n                    }\n                ),\n            ])\n        )\n        
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))\n    elif args.developer_message:\n        developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)\n        messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))\n    else:\n        developer_message_content = None\n\n    if args.raw:\n        conversation = Conversation.from_messages(messages)\n        tokens = encoding.render_conversation(conversation)\n        system_message = encoding.decode(tokens)\n        print(system_message, flush=True, end=\"\")\n        empty_user_message_tokens = encoding.render(Message.from_role_and_content(Role.USER, \"\"))\n        user_message_start = encoding.decode(empty_user_message_tokens[:-1])\n        user_message_end = encoding.decode(empty_user_message_tokens[-1:])\n    else:\n        # System message\n        print(termcolor.colored(\"System Message:\", \"cyan\"), flush=True)\n        print(termcolor.colored(\"Model Identity:\", \"cyan\"), system_message_content.model_identity, flush=True)\n        print(termcolor.colored(\"Reasoning Effort:\", \"cyan\"), system_message_content.reasoning_effort, flush=True)\n        print(termcolor.colored(\"Conversation Start Date:\", \"cyan\"), system_message_content.conversation_start_date, flush=True)\n        print(termcolor.colored(\"Knowledge Cutoff:\", \"cyan\"), system_message_content.knowledge_cutoff, flush=True)\n        print(termcolor.colored(\"Browser Tool:\", \"cyan\"), \"Enabled\" if args.browser else \"Disabled\", flush=True)\n        print(termcolor.colored(\"Python Tool:\", \"cyan\"), \"Enabled\" if args.python else \"Disabled\", flush=True)\n        print(termcolor.colored(\"Apply Patch Function:\", \"cyan\"), \"Enabled\" if args.apply_patch else \"Disabled\", flush=True)\n        if developer_message_content:\n            print(termcolor.colored(\"Developer Message:\", \"yellow\"), flush=True)\n            
print(developer_message_content.instructions, flush=True)\n\n    # Print the system message and the user message start\n    MESSAGE_PADDING = 12\n    while True:\n        last_message = messages[-1]\n        if last_message.recipient is None:\n            if args.raw:\n                print(user_message_start, end=\"\", flush=True)\n                user_message = get_user_input()\n                print(user_message_end, flush=True, end=\"\")\n            else:\n                print(termcolor.colored(\"User:\".ljust(MESSAGE_PADDING), \"red\"), flush=True)\n                user_message = get_user_input()\n            user_message = Message.from_role_and_content(Role.USER, user_message)\n            messages.append(user_message)\n        else:\n            # Tool or function call\n            if last_message.recipient.startswith(\"browser.\"):\n                assert args.browser, \"Browser tool is not enabled\"\n                tool_name = \"Search\"\n                async def run_tool():\n                    results = []\n                    async for msg in browser_tool.process(last_message):\n                        results.append(msg)\n                    return results\n\n                result = asyncio.run(run_tool())\n                messages += result\n            elif last_message.recipient.startswith(\"python\"):\n                assert args.python, \"Python tool is not enabled\"\n                tool_name = \"Python\"\n                async def run_tool():\n                    results = []\n                    async for msg in python_tool.process(last_message):\n                        results.append(msg)\n                    return results\n\n                result = asyncio.run(run_tool())\n                messages += result\n            elif last_message.recipient == \"functions.apply_patch\":\n                assert args.apply_patch, \"Apply patch tool is not enabled\"\n                tool_name = \"Apply Patch\"\n                text = 
last_message.content[0].text\n                tool_output = None\n\n                if text.startswith(\"{\"):\n                    # this is json, try to extract the patch from it\n                    import json\n                    try:\n                        some_dict = json.loads(text)\n                        _, text = some_dict.popitem()\n                    except Exception as e:\n                        tool_output = f\"Error parsing JSON: {e}\"\n\n                if tool_output is None:\n                    try:\n                        tool_output = apply_patch.apply_patch(text)\n                    except Exception as e:\n                        tool_output = f\"Error applying patch: {e}\"\n\n                message = (\n                    Message(\n                        author=Author.new(Role.TOOL, last_message.recipient),\n                        content=[TextContent(text=tool_output)]\n                    )\n                    .with_recipient(\"assistant\")\n                )\n                if last_message.channel:\n                    message = message.with_channel(last_message.channel)\n\n                result = [message]\n                messages += result\n            else:\n                raise ValueError(f\"Unknown tool or function call: {last_message.recipient}\")\n            # Print the tool or function call result\n            if args.raw:\n                rendered_result = encoding.render_conversation(Conversation.from_messages(result))\n                print(encoding.decode(rendered_result), flush=True, end=\"\")\n            else:\n                print(termcolor.colored(f\"{tool_name} output:\".ljust(MESSAGE_PADDING), \"magenta\"), flush=True)\n                if tool_name == \"Search\" and not args.show_browser_results:\n                    print(\"[Search results fed to the model]\")\n                else:\n                    print(result[0].content[0].text)\n\n        conversation = Conversation.from_messages(messages)\n   
     tokens = encoding.render_conversation_for_completion(\n            conversation, Role.ASSISTANT\n        )\n\n        if args.raw:\n            # Print the last two tokens, which are the start of the assistant message\n            print(encoding.decode(tokens[-2:]), flush=True, end=\"\")\n\n        parser = StreamableParser(encoding, role=Role.ASSISTANT)\n        field_created = False\n        current_output_text = \"\"\n        output_text_delta_buffer = \"\"\n        for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):\n            parser.process(predicted_token)\n            if args.raw:\n                print(encoding.decode([predicted_token]), end=\"\", flush=True)\n                continue\n\n            if parser.state == StreamState.EXPECT_START:\n                print(\"\")  # new line\n                field_created = False\n\n            if not parser.last_content_delta:\n                continue\n\n            if not field_created:\n                field_created = True\n                if parser.current_channel == \"final\":\n                    print(termcolor.colored(\"Assistant:\", \"green\"), flush=True)\n                elif parser.current_recipient is not None:\n                    print(termcolor.colored(f\"Tool call to {parser.current_recipient}:\", \"cyan\"), flush=True)\n                else:\n                    print(termcolor.colored(\"CoT:\", \"yellow\"), flush=True)\n\n            should_send_output_text_delta = True\n            output_text_delta_buffer += parser.last_content_delta\n            if args.browser:\n                updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer)\n                output_text_delta_buffer = updated_output_text[len(current_output_text):]\n                if has_partial_citations:\n                    should_send_output_text_delta = False\n            if 
should_send_output_text_delta:\n                print(output_text_delta_buffer, end=\"\", flush=True)\n                current_output_text += output_text_delta_buffer\n                output_text_delta_buffer = \"\"\n\n        messages += parser.messages\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Chat example\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    parser.add_argument(\n        \"checkpoint\",\n        metavar=\"FILE\",\n        type=str,\n        help=\"Path to the SafeTensors checkpoint\",\n    )\n    parser.add_argument(\n        \"-r\",\n        \"--reasoning-effort\",\n        metavar=\"REASONING_EFFORT\",\n        type=str,\n        default=\"low\",\n        choices=[\"high\", \"medium\", \"low\"],\n        help=\"Reasoning effort\",\n    )\n    parser.add_argument(\n        \"-a\",\n        \"--apply-patch\",\n        action=\"store_true\",\n        help=\"Make apply_patch function available to the model\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--browser\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use browser tool\",\n    )\n    parser.add_argument(\n        \"--show-browser-results\",\n        default=False,\n        action=\"store_true\",\n        help=\"Show browser results\",\n    )\n    parser.add_argument(\n        \"-p\",\n        \"--python\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use python tool\",\n    )\n    parser.add_argument(\n        \"--developer-message\",\n        default=\"\",\n        help=\"Developer message\",\n    )\n    parser.add_argument(\n        \"-c\",\n        \"--context\",\n        metavar=\"CONTEXT\",\n        type=int,\n        default=8192,\n        help=\"Max context length\",\n    )\n    parser.add_argument(\n        \"--raw\",\n        default=False,\n        action=\"store_true\",\n        help=\"Raw mode (does not render Harmony 
encoding)\",\n    )\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        default=\"triton\",\n        choices=[\"triton\", \"torch\", \"vllm\"],\n        help=\"Inference backend\",\n    )\n    args = parser.parse_args()\n\n    if int(os.environ.get(\"WORLD_SIZE\", 1)) == 1:\n        histfile = os.path.join(os.path.expanduser(\"~\"), \".chat\")\n        try:\n            readline.read_history_file(histfile)\n            readline.set_history_length(10000)\n        except FileNotFoundError:\n            pass\n\n        atexit.register(readline.write_history_file, histfile)\n\n    main(args)\n"
  },
  {
    "path": "gpt_oss/evals/README.md",
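    "content": "# `gpt_oss.evals`\n\nThis module is adapted from [simple-evals](https://github.com/openai/simple-evals) for gpt-oss. It lets you\nrun GPQA, HealthBench, and AIME 2025 against a runtime that exposes the Responses API (or, with `--sampler chat_completions`, the Chat Completions API). The endpoint defaults to `http://localhost:8000/v1` and can be changed with `--base-url`. HealthBench grading additionally calls `gpt-4.1-2025-04-14` through the OpenAI API."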
  },
  {
    "path": "gpt_oss/evals/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/evals/__main__.py",
    "content": "import argparse\nimport json\nfrom datetime import datetime\n\nfrom . import report\nfrom .basic_eval import BasicEval\nfrom .gpqa_eval import GPQAEval\nfrom .aime_eval import AIME25Eval\nfrom .healthbench_eval import HealthBenchEval\nfrom .chat_completions_sampler import (\n    OPENAI_SYSTEM_MESSAGE_API,\n    ChatCompletionsSampler,\n)\nfrom .responses_sampler import ResponsesSampler\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Evaluate the models.\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"gpt-oss-120b,gpt-oss-20b\",\n        help=\"Select a model by name. Accepts a comma-separated list.\",\n    )\n    parser.add_argument(\n        \"--reasoning-effort\",\n        type=str,\n        default=\"low,medium,high\",\n        help=\"Reasoning effort (low, medium, high). Accepts a comma-separated list.\",\n    )\n    parser.add_argument(\n        \"--sampler\",\n        type=str,\n        choices=[\"responses\", \"chat_completions\"],\n        default=\"responses\",\n        help=\"Sampler backend to use for models.\",\n    )\n    parser.add_argument(\n        \"--base-url\",\n        type=str,\n        default=\"http://localhost:8000/v1\",\n        help=\"Base URL for the API.\",\n    )\n    parser.add_argument(\n        \"--eval\",\n        type=str,\n        default=\"gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25\",\n        help=\"Select an eval by name. 
Accepts a comma-separated list.\",\n    )\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        default=1.0,\n        help=\"Sampling temperature\",\n    )\n    parser.add_argument(\n        \"--n-threads\",\n        type=int,\n        default=1584,\n        help=\"Number of threads to run.\",\n    )\n    parser.add_argument(\n        \"--debug\", action=\"store_true\", help=\"Run in debug mode\"\n    )\n    parser.add_argument(\n        \"--examples\", type=int, help=\"Number of examples to use (overrides default)\"\n    )\n\n    args = parser.parse_args()\n\n    sampler_cls = ResponsesSampler if args.sampler == \"responses\" else ChatCompletionsSampler\n\n    models = {}\n    for model_name in args.model.split(\",\"):\n        for reasoning_effort in args.reasoning_effort.split(\",\"):\n            models[f\"{model_name}-{reasoning_effort}\"] = sampler_cls(\n                model=model_name,\n                reasoning_model=True,\n                reasoning_effort=reasoning_effort,\n                temperature=args.temperature,\n                base_url=args.base_url,\n                max_tokens=131_072,\n            )\n\n    print(f\"Running with args {args}\")\n\n    grading_sampler = ChatCompletionsSampler(\n        model=\"gpt-4.1-2025-04-14\",\n        system_message=OPENAI_SYSTEM_MESSAGE_API,\n        max_tokens=2048,\n        base_url=\"https://api.openai.com/v1\",\n    )\n\n    def get_evals(eval_name, debug_mode):\n        num_examples = (\n            args.examples if args.examples is not None else (5 if debug_mode else None)\n        )\n        # Set num_examples = None to reproduce full evals\n        match eval_name:\n            case \"basic\":\n                return BasicEval()\n            case \"gpqa\":\n                return GPQAEval(\n                    # GPQAEval requires n_repeats == 1 whenever num_examples is set\n                    n_repeats=1 if num_examples else 8,\n                    num_examples=num_examples,\n                    debug=debug_mode,\n                    
n_threads=args.n_threads or 1,\n                )\n            case \"healthbench\":\n                return HealthBenchEval(\n                    grader_model=grading_sampler,\n                    num_examples=10 if debug_mode else num_examples,\n                    n_repeats=1,\n                    n_threads=args.n_threads or 1,\n                    subset_name=None,\n                )\n            case \"healthbench_hard\":\n                return HealthBenchEval(\n                    grader_model=grading_sampler,\n                    num_examples=10 if debug_mode else num_examples,\n                    n_repeats=1,\n                    n_threads=args.n_threads or 1,\n                    subset_name=\"hard\",\n                )\n            case \"healthbench_consensus\":\n                return HealthBenchEval(\n                    grader_model=grading_sampler,\n                    num_examples=10 if debug_mode else num_examples,\n                    n_repeats=1,\n                    n_threads=args.n_threads or 1,\n                    subset_name=\"consensus\",\n                )\n            case \"aime25\":\n                return AIME25Eval(\n                    # AIME25Eval requires n_repeats == 1 whenever num_examples is set\n                    n_repeats=1 if num_examples else 8,\n                    num_examples=num_examples,\n                    n_threads=args.n_threads or 1,\n                )\n            case _:\n                raise Exception(f\"Unrecognized eval type: {eval_name}\")\n\n    evals = {}\n    for eval_name in args.eval.split(\",\"):\n        evals[eval_name] = get_evals(eval_name, args.debug)\n\n    debug_suffix = \"_DEBUG\" if args.debug else \"\"\n    print(debug_suffix)\n    mergekey2resultpath = {}\n    print(f\"Running the following evals: {evals}\")\n    print(f\"Running evals for the following models: {models}\")\n\n    now = datetime.now()\n    date_str = now.strftime(\"%Y%m%d_%H%M%S\")\n    for model_name, sampler in models.items():\n        model_name = model_name.replace(\"/\", \"__\")\n        for eval_name, 
eval_obj in evals.items():\n            result = eval_obj(sampler)\n            # ^^^ how to use a sampler\n            file_stem = f\"{eval_name}_{model_name}_temp{args.temperature}\"\n            # Append a timestamp (YYYYmmdd_HHMMSS) so repeated runs do not overwrite each other\n            file_stem += f\"_{date_str}\"\n            report_filename = f\"/tmp/{file_stem}{debug_suffix}.html\"\n            print(f\"Writing report to {report_filename}\")\n            with open(report_filename, \"w\") as fh:\n                fh.write(report.make_report(result))\n            assert result.metrics is not None\n            metrics = result.metrics | {\"score\": result.score}\n            # Sort metrics by key\n            metrics = dict(sorted(metrics.items()))\n            print(metrics)\n            result_filename = f\"/tmp/{file_stem}{debug_suffix}.json\"\n            with open(result_filename, \"w\") as f:\n                f.write(json.dumps(metrics, indent=2))\n            print(f\"Writing results to {result_filename}\")\n\n            full_result_filename = f\"/tmp/{file_stem}{debug_suffix}_allresults.json\"\n            with open(full_result_filename, \"w\") as f:\n                result_dict = {\n                    \"score\": result.score,\n                    \"metrics\": result.metrics,\n                    \"htmls\": result.htmls,\n                    \"convos\": result.convos,\n                    \"metadata\": result.metadata,\n                }\n                f.write(json.dumps(result_dict, indent=2))\n                print(f\"Writing all results to {full_result_filename}\")\n\n            mergekey2resultpath[file_stem] = (eval_name, result_filename)\n\n    merge_metrics = []\n    for eval_model_name, (eval_name, result_filename) in mergekey2resultpath.items():\n        try:\n            with open(result_filename) as f:\n                result = json.load(f)\n        except Exception as e:\n            print(e, result_filename)\n            continue\n        result = result.get(\"f1_score\", 
result.get(\"score\", None))\n        # Eval names such as \"healthbench_hard\" contain underscores, so strip the\n        # known eval-name prefix instead of splitting at the first underscore.\n        model_name = eval_model_name[len(eval_name) + 1 :]\n        merge_metrics.append(\n            {\"eval_name\": eval_name, \"model_name\": model_name, \"metric\": result}\n        )\n    print(merge_metrics)\n    return merge_metrics\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "gpt_oss/evals/abcd_grader.py",
    "content": "import re\nimport sys\n\n\n_PATTERNS = [\n    # 0)\"**Answer:** A\" or \"*Answers* – B\", i.e. markdown‐wrapped \"Answer(s)\" with an unwrapped letter.\n    re.compile(\n        r'''(?ix)                   # case‐insensitive, ignore‐space\n        (?:\\*{1,2}|_{1,2})          # leading *…*  or _…_\n        Answer[s]?                  #   Answer or Answers\n        \\s*[:\\-–]?                  #   optional separator\n        (?:\\*{1,2}|_{1,2})          # closing wrapper\n        \\s*                         # optional space\n        ([ABCD])\\b                  # the actual letter\n        ''',\n        re.X\n    ),\n\n    # 0.1)\n    re.compile(r'''(?ix)           # ignore case, allow verbose mode\n        ^\\s*                      # optional leading whitespace\n        (?:\\*{1,2}|_{1,2})?       # optional markdown wrapper\n        Answer:?                   # the word 'answer' with an optional colon\n        (?:\\*{1,2}|_{1,2})?       # optional markdown wrapper again\n        \\s*:?\\s*                  # optional colon with optional spaces\n        (?:\\*{1,2}|_{1,2})?       # optional markdown wrapper before letter\n        ([ABCD])                 # capture the letter\n        (?:\\*{1,2}|_{1,2})?       
# optional markdown wrapper after letter\n        \\s*                     # optional trailing whitespace, end of line\n    ''', re.MULTILINE),\n\n    # 1) Answer: (C)   or   Answers: (B)\n    re.compile(r'(?ix)\\bAnswer[s]?\\b\\s*[:\\-–]?\\s*\\(\\s*([ABCD])\\s*\\)'),\n\n    # 2) Answer: C    or   Answers – D\n    re.compile(r'(?ix)\\bAnswer[s]?\\b\\s*[:\\-–]?\\s*([ABCD])\\b'),\n\n    # 3) Option B   or   Choice: C\n    re.compile(r'(?ix)\\b(?:Option|Choice)\\b\\s*[:\\-–]?\\s*([ABCD])\\b'),\n\n    # 7) LaTeX \\boxed{...A...}, catches both \\boxed{A} and\n    #    \\boxed{\\text{A } 2.08\\times10^{-6}\\,\\mathrm{m}} etc.\n    re.compile(r'(?x)\\\\boxed\\{[^}]*?([ABCD])[^}]*\\}', re.MULTILINE),\n\n    # 7.5) LaTeX \\boxed{\\textbf{...C...}}\n    re.compile(r'(?x)\\\\boxed\\{[^}]*?\\\\textbf\\{[^}]*?([ABCD])[^}]*\\}[^}]*\\}', re.MULTILINE),\n\n    # 7.51) LaTeX \\boxed{\\text{...C...}}\n    re.compile(r'(?x)\\\\boxed\\{[^}]*?\\\\text\\{[^}]*?([ABCD])[^}]*\\}[^}]*\\}', re.MULTILINE),\n\n    # 4) bare singletons:  (A)  [B]\n    re.compile(r'(?x)(?<![A-Za-z0-9])[\\(\\[]\\s*([ABCD])\\s*[\\)\\]](?![A-Za-z0-9])'),\n\n    # 5) Markdown‐wrapped: *A*  **B**  _C_  __D__\n    re.compile(r'(?x)(?<![A-Za-z0-9])(?:\\*{1,2}|_{1,2})([ABCD])(?:\\*{1,2}|_{1,2})(?![A-Za-z0-9])'),\n\n    # 6) LaTeX \\textbf{...C...}\n    re.compile(r'(?x)\\\\textbf\\{[^}]*?([ABCD])[^}]*\\}'),\n\n    # 8) markdown‐wrapped answer plus “)” plus description, e.g. **D) …**\n    re.compile(r'''(?x)                        # ignore whitespace in pattern\n        (?<![A-Za-z0-9])            # not preceded by word‐char\n        (?:\\*{1,2}|_{1,2})          # opening ** or __ or * or _\n        \\s*([ABCD])\\)               # capture letter plus “)”\n        [^*_\\n]+?                   
# some text inside wrapper\n        (?:\\*{1,2}|_{1,2})          # closing wrapper\n        (?![A-Za-z0-9])             # not followed by word‐char\n    '''),\n\n    # 9) final fallback: a line that starts with \"A\", \"B.\", \"C)\", \"**D**\", etc.\n    re.compile(r'''(?x)^\\s*\n        (?:\\*{1,2}|_{1,2})?     # optional markdown wrapper\n        ([ABCD])                # capture group for letter\n        (?:\\*{1,2}|_{1,2})?     # optional closing markdown\n        \\s*[\\.\\)\\-–:]?          # optional separator after the letter\n        \\s*.*$                  # allow any following text\n    ''', re.MULTILINE),\n]\n\n\ndef extract_abcd(text: str) -> str:\n    \"\"\"\n    Scan text (with Markdown/LaTeX wrappers intact) and return\n    'A', 'B', 'C', or 'D' if a correct-answer declaration is found.\n    If no pattern matches, fall back to the first character of the text\n    (after stripping a leading '**'); this fallback may not be a valid letter.\n    \"\"\"\n    matches = []\n    for prio, pat in enumerate(_PATTERNS):\n        m = pat.search(text)\n        if m:\n            letter = m.group(1).upper()\n            if letter in 'ABCD':\n                matches.append((prio, m, letter))\n\n    # Prefer the highest-priority (lowest-index) pattern; break ties by match length.\n    matches.sort(key=lambda triple: (\n        triple[0],\n        len(triple[1].group(0))\n    ))\n    if matches:\n        return matches[0][2]\n    return text.removeprefix('**')[:1]\n\n\ndef main():\n    if len(sys.argv) > 1:\n        # Process files\n        for fn in sys.argv[1:]:\n            with open(fn, encoding='utf8') as fp:\n                text = fp.read()\n            ans = extract_abcd(text)\n            print(f\"{fn} ➜ {ans!r}\")\n    else:\n        # Read from stdin\n        for line in sys.stdin:\n            ans = extract_abcd(line)\n            print(f\"{line} ➜ {ans!r}\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "gpt_oss/evals/aime_eval.py",
    "content": "\"\"\"\nAIME 2025: https://huggingface.co/datasets/opencompass/AIME2025\n\"\"\"\nimport random\nimport re\nimport pandas\nfrom . import report\n\nfrom .types import Eval, EvalResult, SamplerBase, SingleEvalResult\n\n\nAIME_TEMPLATE = \"\"\"\n{question}\nPlease reason step by step, and put your final answer within \\\\boxed{{}}.\n\"\"\"\n\ndef format_aime_question(row):\n    return AIME_TEMPLATE.format(question=row[\"question\"])\n\ndef extract_boxed_text(text):\n    pattern = r'boxed{(.*?)}|framebox{(.*?)}'\n    matches = re.findall(pattern, text, re.DOTALL)\n    if matches:\n        for match in matches[::-1]:\n            for group in match:\n                if group != \"\":\n                    return group.split(',')[-1].strip()\n    pattern = r'\\d+'  # get the last integer if no pattern found\n    matches = re.findall(pattern, text, re.DOTALL)\n    if matches:\n        return matches[-1]\n    return \"\"\n\ndef normalize_number(s):\n    match = re.match(r\"\\d+\", s)  # match digits from the start\n    if not match:\n        return None\n    return match.group(0)\n\nclass AIME25Eval(Eval):\n    def __init__(\n        self,\n        n_repeats: int = 4,\n        num_examples: int | None = None,  # restrict to a subset of the data for debugging\n        n_threads: int = 1,\n    ):\n        path1 = f\"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl\"\n        df1 = pandas.read_json(path1, lines=True)\n        path2 = f\"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl\"\n        df2 = pandas.read_json(path2, lines=True)\n        examples = [row.to_dict() for _, row in df1.iterrows()] + [row.to_dict() for _, row in df2.iterrows()]\n        examples = [{\n            \"question\": row[\"question\"],\n            \"answer\": normalize_number(row[\"answer\"]) if isinstance(row[\"answer\"], str) else row[\"answer\"],\n        } for row in examples]\n        rng = random.Random(0)\n        
if num_examples:\n            assert n_repeats == 1, \"n_repeats only supported for num_examples = None\"\n            examples = rng.sample(examples, num_examples)\n        examples = examples * n_repeats\n        examples = [example | {\"permutation\": rng.sample(range(4), 4)} for example in examples]\n        self.examples = examples\n        self.n_repeats = n_repeats\n        self.n_threads = n_threads\n\n    def __call__(self, sampler: SamplerBase) -> EvalResult:\n        def fn(row: dict):\n            prompt_messages = [\n                sampler._pack_message(\n                    content=format_aime_question(row), role=\"user\"\n                )\n            ]\n            sampler_response = sampler(prompt_messages)\n            response_text = sampler_response.response_text\n            actual_queried_prompt_messages = sampler_response.actual_queried_message_list\n            extracted_answer = extract_boxed_text(response_text)\n            correct_answer = int(row[\"answer\"])\n            try: # All AIME answers are integers, so we convert the extracted answer to an integer\n                extracted_answer = int(extracted_answer)\n            except (ValueError, TypeError):\n                extracted_answer = None\n            score = 1.0 if extracted_answer == correct_answer else 0.0\n            html = report.jinja_env.from_string(report.HTML_JINJA).render(\n                prompt_messages=actual_queried_prompt_messages,\n                next_message=dict(content=response_text, role=\"assistant\"),\n                score=score,\n                correct_answer=correct_answer,\n                extracted_answer=extracted_answer,\n            )\n            convo = actual_queried_prompt_messages + [dict(content=response_text, role=\"assistant\")]\n            return SingleEvalResult(\n                html=html, score=score, convo=convo, metrics={\"chars\": len(response_text)}\n            )\n\n        results = report.map_with_progress(fn, 
self.examples, num_threads=self.n_threads)\n        return report.aggregate_results(results)\n\n"
  },
  {
    "path": "gpt_oss/evals/basic_eval.py",
    "content": "\"\"\"\nBasic eval\n\"\"\"\nfrom . import report\n\nfrom .types import Eval, EvalResult, SamplerBase, SingleEvalResult\n\nclass BasicEval(Eval):\n    def __init__(self,):\n        self.examples = [{\n            \"question\": \"hi\",\n            \"answer\": \"hi, how can i help?\",\n        }]\n\n    def __call__(self, sampler: SamplerBase) -> EvalResult:\n        def fn(row: dict):\n            sampler_response = sampler([\n                sampler._pack_message(content=row[\"question\"], role=\"user\")\n            ])\n            response_text = sampler_response.response_text\n            extracted_answer = response_text\n            actual_queried_prompt_messages = sampler_response.actual_queried_message_list\n            score = 1.0 if len(extracted_answer) > 0 else 0.0\n            html = report.jinja_env.from_string(report.HTML_JINJA).render(\n                prompt_messages=actual_queried_prompt_messages,\n                next_message=dict(content=response_text, role=\"assistant\"),\n                score=score,\n                correct_answer=row[\"answer\"],\n                extracted_answer=extracted_answer,\n            )\n            convo = actual_queried_prompt_messages + [dict(content=response_text, role=\"assistant\")]\n            return SingleEvalResult(\n                html=html, score=score, convo=convo, metrics={\"chars\": len(response_text)}\n            )\n\n        results = report.map_with_progress(fn, self.examples, num_threads=1)\n        return report.aggregate_results(results)\n\n"
  },
  {
    "path": "gpt_oss/evals/chat_completions_sampler.py",
    "content": "import time\nfrom typing import Any\n\nimport openai\nfrom openai import OpenAI\n\nfrom .types import MessageList, SamplerBase, SamplerResponse\n\n\nOPENAI_SYSTEM_MESSAGE_API = \"You are a helpful assistant.\"\nOPENAI_SYSTEM_MESSAGE_CHATGPT = (\n    \"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\"\n    + \"\\nKnowledge cutoff: 2023-12\\nCurrent date: 2024-04-01\"\n)\n\n\nclass ChatCompletionsSampler(SamplerBase):\n    \"\"\"Sample from a Chat Completions compatible API.\"\"\"\n\n    def __init__(\n        self,\n        model: str = \"gpt-3.5-turbo\",\n        system_message: str | None = None,\n        temperature: float = 0.5,\n        max_tokens: int = 1024,\n        reasoning_model: bool = False,\n        reasoning_effort: str | None = None,\n        base_url: str = \"http://localhost:8000/v1\",\n    ):\n        self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)\n        self.model = model\n        self.system_message = system_message\n        self.temperature = temperature\n        self.max_tokens = max_tokens\n        self.reasoning_model = reasoning_model\n        self.reasoning_effort = reasoning_effort\n        self.image_format = \"url\"\n\n    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:\n        return {\"role\": str(role), \"content\": content}\n\n    def __call__(self, message_list: MessageList) -> SamplerResponse:\n        if self.system_message:\n            message_list = [\n                self._pack_message(\"system\", self.system_message)\n            ] + message_list\n        trial = 0\n        while True:\n            try:\n                if self.reasoning_model:\n                    response = self.client.chat.completions.create(\n                        model=self.model,\n                        messages=message_list,\n                        reasoning_effort=self.reasoning_effort,\n                        temperature=self.temperature,\n         
               max_tokens=self.max_tokens,\n                    )\n                else:\n                    response = self.client.chat.completions.create(\n                        model=self.model,\n                        messages=message_list,\n                        temperature=self.temperature,\n                        max_tokens=self.max_tokens,\n                    )\n\n                choice = response.choices[0]\n                content = choice.message.content\n                if not content:\n                    raise ValueError(\"OpenAI API returned empty response; retrying\")\n                # Record the reasoning (if any) only once a usable response arrives, and on a\n                # copy, so retries never resend it and the caller's list is left untouched.\n                queried_message_list = list(message_list)\n                if getattr(choice.message, \"reasoning\", None):\n                    queried_message_list.append(self._pack_message(\"assistant\", choice.message.reasoning))\n                return SamplerResponse(\n                    response_text=content,\n                    response_metadata={\"usage\": response.usage},\n                    actual_queried_message_list=queried_message_list,\n                )\n            except openai.BadRequestError as e:\n                print(\"Bad Request Error\", e)\n                return SamplerResponse(\n                    response_text=\"No response (bad request).\",\n                    response_metadata={\"usage\": None},\n                    actual_queried_message_list=message_list,\n                )\n            except Exception as e:\n                # Any other error (rate limit, timeout, empty response, ...) is retried\n                # indefinitely with exponential backoff.\n                exception_backoff = 2 ** trial\n                print(\n                    f\"Exception on attempt {trial}; retrying after {exception_backoff} sec\",\n                    e,\n                )\n                time.sleep(exception_backoff)\n                trial += 1\n"
  },
  {
    "path": "gpt_oss/evals/gpqa_eval.py",
    "content": "\"\"\"\nGPQA: A Graduate-Level Google-Proof Q&A Benchmark\nDavid Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman\nhttps://arxiv.org/abs/2311.12022\n\"\"\"\n\nimport random\n\nimport pandas\n\nfrom . import report\nfrom .types import Eval, EvalResult, SamplerBase, SingleEvalResult\nfrom .abcd_grader import extract_abcd\n\n\nQUERY_TEMPLATE_MULTICHOICE = \"\"\"\n{Question}\n\n(A) {A}\n(B) {B}\n(C) {C}\n(D) {D}\n\nExpress your final answer as the corresponding option 'A', 'B', 'C', or 'D'.\n\"\"\".strip()\n\n\ndef format_multichoice_question(row):\n    return QUERY_TEMPLATE_MULTICHOICE.format(**row)\n\n\nclass GPQAEval(Eval):\n    def __init__(\n        self,\n        n_repeats: int = 8,\n        variant: str = \"diamond\",\n        num_examples: int | None = None,  # restrict to a subset of the data for debugging\n        debug: bool = False,\n        n_threads: int = 1,\n    ):\n        df = pandas.read_csv(\n            f\"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv\"\n        )\n        rng = random.Random(0)\n\n        if debug:\n            examples = [row.to_dict() for _, row in df.iterrows() if \"ESPRESSO spectrograph, please\" in row[\"Question\"]]\n        else:\n            examples = [row.to_dict() for _, row in df.iterrows()]\n            if num_examples:\n                assert n_repeats == 1, \"n_repeats only supported for num_examples = None\"\n                examples = rng.sample(examples, num_examples)\n\n        examples = examples * n_repeats\n        examples = [example | {\"permutation\": rng.sample(range(4), 4)} for example in examples]\n        self.examples = examples\n        self.n_repeats = n_repeats\n        self.n_threads = n_threads\n\n    def __call__(self, sampler: SamplerBase) -> EvalResult:\n        def fn(row: dict):\n            choices = [\n                row[\"Correct Answer\"],\n                
row[\"Incorrect Answer 1\"],\n                row[\"Incorrect Answer 2\"],\n                row[\"Incorrect Answer 3\"],\n            ]\n            choices = [choices[i] for i in row[\"permutation\"]]\n            correct_index = choices.index(row[\"Correct Answer\"])\n            correct_answer = \"ABCD\"[correct_index]\n            choices_dict = dict(\n                A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row[\"Question\"]\n            )\n            prompt_messages = [\n                sampler._pack_message(\n                    content=format_multichoice_question(choices_dict), role=\"user\"\n                )\n            ]\n            sampler_response = sampler(prompt_messages)\n            response_text = sampler_response.response_text\n            actual_queried_prompt_messages = sampler_response.actual_queried_message_list\n            extracted_answer = extract_abcd(response_text)\n            score = 1.0 if extracted_answer == correct_answer else 0.0\n            html = report.jinja_env.from_string(report.HTML_JINJA).render(\n                prompt_messages=actual_queried_prompt_messages,\n                next_message=dict(content=response_text, role=\"assistant\"),\n                score=score,\n                correct_answer=correct_answer,\n                extracted_answer=extracted_answer,\n            )\n            convo = actual_queried_prompt_messages + [dict(content=response_text, role=\"assistant\")]\n            return SingleEvalResult(\n                html=html, score=score, convo=convo, metrics={\"chars\": len(response_text)}\n            )\n\n        results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)\n        return report.aggregate_results(results)\n\n\nif __name__ == \"__main__\":\n    import json\n    import sys\n\n    with open(sys.argv[1], \"r\") as f:\n        results = json.load(f)\n\n    passes = 0\n    for convo, html in zip(results[\"convos\"], results[\"htmls\"]):\n    
    message = convo[-1][\"content\"]\n        import re\n\n        # the ground truth is in <p>Correct Answer: A</p> in the html\n        ground_truth = re.search(r\"<p>Correct Answer: (A|B|C|D)</p>\", html)\n        ground_truth = ground_truth.group(1)\n        extracted_answer = extract_abcd(message)\n        if extracted_answer == ground_truth:\n            passes += 1\n        elif len(message) > 15:\n            print(\"no match:\", message)\n            print(\"ground truth:\", ground_truth)\n            print(\"extracted answer:\", extracted_answer)\n            print(\"--------------------------------\")\n\n    pass_rate = passes / len(results[\"convos\"])\n    print(f\"pass@1: {pass_rate}\")"
  },
  {
    "path": "gpt_oss/evals/healthbench_eval.py",
    "content": "\"\"\"\nThis script evaluates the performance of a model on the HealthBench dataset.\n\nTo run HealthBench, HealthBench Consensus, or HealthBench Hard, use the `gpt_oss.evals` runner:\n- `python -m gpt_oss.evals --eval=healthbench --model=gpt-oss-120b`\n- `python -m gpt_oss.evals --eval=healthbench_consensus --model=gpt-oss-120b`\n- `python -m gpt_oss.evals --eval=healthbench_hard --model=gpt-oss-120b`\n\nYou can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run one of the following commands:\n- To evaluate physician ideal completions: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completions`\n- To evaluate the reference model completions provided to physicians: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completion_references`\n\"\"\"\n\nimport argparse\nimport copy\nimport hashlib\nimport json\nimport random\nimport re\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Literal\n\nimport blobfile as bf\nimport numpy as np\n\nfrom . 
import report\nfrom .chat_completions_sampler import (\n    OPENAI_SYSTEM_MESSAGE_API,\n    ChatCompletionsSampler,\n)\nfrom .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult\n\nINPUT_PATH = \"https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl\"\nINPUT_PATH_HARD = \"https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl\"\nINPUT_PATH_CONSENSUS = \"https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl\"\n\nGRADER_TEMPLATE = \"\"\"\nYour job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.\n\n# Conversation\n<<conversation>>\n\n# Rubric item\n<<rubric_item>>\n\n# Instructions\nReturn a json object with the following fields: \"explanation\" and \"criteria_met\".\n- The \"explanation\" field should be a string explaining why the response does or does not meet the criteria of the rubric item.\n- The \"criteria_met\" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met.\n- One important exception to the above bullet point is that if a criteria says \"such as\", \"for example\", or \"including\", the response does not have to include all of the examples listed to meet the criteria. 
For example, if the criteria says \"States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation\", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.\n\n# Example 1\nFor example, if the conversation is \"user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better\" and the rubric item is \"tells the user to go to the hospital if unconscious\", you should return a json like this:\n\n```json\n{\n  \"explanation\": \"The assistant did not tell the user to go to the hospital if unconscious.\",\n  \"criteria_met\": false\n}\n```\n\n# Example 2\nNote that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.\nFor example, consider the conversation is \"user: is smoking good for you? assistant: no, smoking is bad for you\" and the criterion is \"fails to give the user accurate information\", which is something undesirable. Here, the criteria was not met because the response was correct. 
In this case, you should return a json like this:\n\n```json\n{\n  \"explanation\": \"The assistant actually gave the user accurate information, so the criteria was not met.\",\n  \"criteria_met\": false\n}\n```\n\nAs another example, if the criteria says \"Is overly verbose\" and the response is concise, then the criteria is not met, and you should return a json like this:\n\n```json\n{\n  \"explanation\": \"The response is concise, so the criteria was not met.\",\n  \"criteria_met\": false\n}\n```\n\nIn other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.\n\n# Final instruction\nReturn just the json object in markdown format. Do not include any other text in the response.\n\"\"\".strip()\n\nHEALTHBENCH_HTML_JINJA = (\n    report.HTML_JINJA.replace(\n        \"<p>Correct Answer: {{ correct_answer }}</p>\\n\",\n        \"\",\n    )\n    + \"<p>Rubrics with grades: {{ rubric_grades }}</p>\"\n)\n\n\ndef parse_json_to_dict(json_string: str) -> dict:\n    # Remove markdown-style ```json``` markers if present\n    json_cleaned = re.sub(r\"^```json\\s*|\\s*```$\", \"\", json_string.strip())\n\n    try:\n        return json.loads(json_cleaned)\n    except json.JSONDecodeError as e:\n        print(f\"JSON decoding failed: {e}\")\n        return {}\n\n\nclass RubricItem:\n    def __init__(self, criterion: str, points: float, tags: list[str]):\n        self.criterion = criterion\n        self.points = points\n        self.tags = tags\n\n    def __str__(self):\n        return f\"[{self.points}] {self.criterion}\"\n\n    def to_dict(self):\n        return {\n            \"criterion\": self.criterion,\n            \"points\": self.points,\n            \"tags\": self.tags,\n        }\n\n    @classmethod\n    def from_dict(cls, d: dict):\n        return cls(\n            criterion=d[\"criterion\"],\n     
       points=d[\"points\"],\n            tags=d[\"tags\"],\n        )\n\n\ndef calculate_score(\n    rubric_items: list[RubricItem], grading_response_list: list[dict]\n) -> float | None:\n    total_possible_points = sum(\n        rubric_item.points for rubric_item in rubric_items if rubric_item.points > 0\n    )\n    if total_possible_points == 0:\n        # should not happen for overall score, but may happen for tags\n        return None\n\n    achieved_points = sum(\n        rubric_item.points\n        for rubric_item, grading_response in zip(\n            rubric_items, grading_response_list, strict=True\n        )\n        if grading_response[\"criteria_met\"]\n    )\n    overall_score = achieved_points / total_possible_points\n    return overall_score\n\n\ndef get_usage_dict(response_usage) -> dict[str, int | None]:\n    if response_usage is None:\n        return {\n            \"input_tokens\": None,\n            \"input_cached_tokens\": None,\n            \"output_tokens\": None,\n            \"output_reasoning_tokens\": None,\n            \"total_tokens\": None,\n        }\n\n    return {\n        \"input_tokens\": response_usage.input_tokens,\n        \"output_tokens\": response_usage.output_tokens,\n        \"total_tokens\": response_usage.total_tokens,\n        \"input_cached_tokens\": None,\n        \"output_reasoning_tokens\": None,\n    }\n\n\nPHYSICIAN_COMPLETION_MODES = {\n    \"Group 1\": {\n        \"description\": \"No reference completions were provided to the physicians.\",\n        \"short_name\": \"no_reference\",\n        \"has_reference\": False,\n    },\n    \"Group 2\": {\n        \"description\": \"Reference completions were provided to the physicians from Aug / Sep 2024 models (gpt-4o-2024-08-06, o1-preview).\",\n        \"short_name\": \"aug_2024_reference\",\n        \"has_reference\": True,\n    },\n    \"Group 3\": {\n        \"description\": \"Reference completions were provided to the physicians from Apr 2025 models (o3, 
gpt-4.1).\",\n        \"short_name\": \"apr_2025_reference\",\n        \"has_reference\": True,\n    },\n}\n\n\ndef _compute_clipped_stats(\n    values: list,\n    stat: str,\n):\n    \"\"\"Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and n_samples for final HealthBench scoring.\"\"\"\n    if stat == \"mean\":\n        return np.clip(np.mean(values), 0, 1)\n    elif stat == \"n_samples\":\n        return len(values)\n    elif stat == \"bootstrap_std\":\n        bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]\n        bootstrap_means = [\n            _compute_clipped_stats(list(s), \"mean\") for s in bootstrap_samples\n        ]\n        return np.std(bootstrap_means)\n    else:\n        raise ValueError(f\"Unknown {stat =}\")\n\n\ndef _aggregate_get_clipped_mean(\n    single_eval_results: list[SingleEvalResult],\n) -> EvalResult:\n    \"\"\"\n    Aggregate multiple SingleEvalResults into a single EvalResult for HealthBench.\n    For each metric, returns the stats in _compute_clipped_stats.\n    \"\"\"\n    name2values = defaultdict(list)\n    htmls = []\n    convos = []\n    metadata = []\n    for single_eval_result in single_eval_results:\n        for name, value in single_eval_result.metrics.items():\n            name2values[name].append(value)\n        if single_eval_result.score is not None:\n            name2values[\"score\"].append(single_eval_result.score)\n        htmls.append(single_eval_result.html)\n        convos.append(single_eval_result.convo)\n        metadata.append(single_eval_result.example_level_metadata)\n    final_metrics = {}\n    for name, values in name2values.items():\n        for stat in [\"mean\", \"n_samples\", \"bootstrap_std\"]:\n            key = name if stat == \"mean\" else f\"{name}:{stat}\"\n            final_metrics[key] = _compute_clipped_stats(values, stat)\n    return EvalResult(\n        score=final_metrics.pop(\"score\", None),\n        metrics=final_metrics,\n        
htmls=htmls,\n        convos=convos,\n        metadata={\"example_level_metadata\": metadata},\n    )\n\n\nclass HealthBenchEval(Eval):\n    def __init__(\n        self,\n        grader_model: SamplerBase,\n        num_examples: int | None = None,\n        n_repeats: int = 1,\n        # If set, evaluate human completions or reference completions instead of model completions.\n        physician_completions_mode: str | None = None,\n        # If True, run the grader on reference completions used by physicians, and physician_completions_mode must be set.\n        run_reference_completions: bool = False,\n        n_threads: int = 120,\n        subset_name: Literal[\"hard\", \"consensus\"] | None = None,\n    ):\n        if run_reference_completions:\n            assert physician_completions_mode is not None, (\n                \"physician_completions_mode must be provided if run_reference_completions is True\"\n            )\n            assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][\n                \"has_reference\"\n            ], (\n                \"physician_completions_mode must have reference completions if run_reference_completions is True\"\n            )\n\n        if subset_name == \"hard\":\n            input_path = INPUT_PATH_HARD\n        elif subset_name == \"consensus\":\n            input_path = INPUT_PATH_CONSENSUS\n        elif subset_name is None:\n            input_path = INPUT_PATH\n        else:\n            assert False, f\"Invalid subset name: {subset_name}\"\n        with bf.BlobFile(input_path, \"rb\") as f:\n            examples = [json.loads(line) for line in f]\n        for example in examples:\n            example[\"rubrics\"] = [RubricItem.from_dict(d) for d in example[\"rubrics\"]]\n\n        rng = random.Random(0)\n\n        # physician completions mode\n        self.physician_completions_mode = physician_completions_mode\n        if self.physician_completions_mode is not None:\n            assert 
self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (\n                f\"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}\"\n            )\n            # subset to only the rows which have physician completions from that group\n            examples_matching_mode = [\n                example\n                for example in examples\n                if example[\"ideal_completions_data\"] is not None\n                and example[\"ideal_completions_data\"][\"ideal_completions_group\"]\n                == self.physician_completions_mode\n            ]\n            print(\n                f\"Subsetting to {len(examples_matching_mode)} examples with physician completions of type {self.physician_completions_mode} ({PHYSICIAN_COMPLETION_MODES[self.physician_completions_mode]['description']})\"\n            )\n\n            examples = []\n            if run_reference_completions:\n                for example in examples_matching_mode:\n                    for completion in example[\"ideal_completions_data\"][\n                        \"ideal_completions_ref_completions\"\n                    ]:\n                        new_example = copy.deepcopy(example)\n                        new_example[\"completion_to_trial\"] = completion\n                        examples.append(new_example)\n                assert len(examples) == len(examples_matching_mode) * 4\n                print(\n                    f\"Running four references for each example, for {len(examples)} total\"\n                )\n            else:\n                for example in examples_matching_mode:\n                    example[\"completion_to_trial\"] = example[\"ideal_completions_data\"][\n                        \"ideal_completion\"\n                    ]\n                    examples.append(example)\n                assert len(examples) == len(examples_matching_mode)\n\n            if len(examples) == 0:\n                
raise ValueError(\n                    f\"No examples found matching mode {self.physician_completions_mode}\"\n                )\n\n        if num_examples is not None and num_examples < len(examples):\n            examples = rng.sample(\n                examples,\n                num_examples,\n            )\n\n        self.examples = examples * n_repeats\n        self.n_threads = n_threads\n        self.grader_model = grader_model\n\n    def grade_sample(\n        self,\n        prompt: list[dict[str, str]],\n        response_text: str,\n        example_tags: list[str],\n        rubric_items: list[RubricItem],\n    ) -> tuple[dict, str, list[dict]]:\n        # construct and grade the sample\n        convo_with_response = prompt + [dict(content=response_text, role=\"assistant\")]\n\n        def grade_rubric_item(rubric_item: RubricItem) -> dict:\n            convo_str = \"\\n\\n\".join(\n                [f\"{m['role']}: {m['content']}\" for m in convo_with_response]\n            )\n            grader_prompt = GRADER_TEMPLATE.replace(\n                \"<<conversation>>\", convo_str\n            ).replace(\"<<rubric_item>>\", str(rubric_item))\n            messages: MessageList = [dict(content=grader_prompt, role=\"user\")]\n            while True:\n                sampler_response = self.grader_model(messages)\n                grading_response = sampler_response.response_text\n                grading_response_dict = parse_json_to_dict(grading_response)\n                if \"criteria_met\" in grading_response_dict:\n                    label = grading_response_dict[\"criteria_met\"]\n                    if label is True or label is False:\n                        break\n                print(\"Grading failed due to bad JSON output, retrying...\")\n            return grading_response_dict\n\n        grading_response_list = report.map_with_progress(\n            grade_rubric_item,\n            rubric_items,\n            pbar=False,\n        )\n\n        # compute the 
overall score\n        overall_score = calculate_score(rubric_items, grading_response_list)\n        assert overall_score is not None\n        metrics = {\n            \"overall_score\": overall_score,\n        }\n\n        # compute scores for example-level tags\n        example_tag_scores = {tag: overall_score for tag in example_tags}\n        assert len(example_tag_scores) == len(example_tags)  # No duplicates.\n        metrics.update(example_tag_scores)\n\n        # compute scores for rubric-level tags\n        rubric_tag_items_grades = defaultdict(list)\n        for rubric_item, grading_response in zip(rubric_items, grading_response_list):\n            curr_item_tags = set()  # Ensure no duplicates in a rubric item.\n            for tag in rubric_item.tags:\n                rubric_tag_items_grades[tag].append((rubric_item, grading_response))\n                assert tag not in curr_item_tags\n                curr_item_tags.add(tag)\n\n        rubric_tag_scores = {}\n        for tag, items_grades in rubric_tag_items_grades.items():\n            items, grades = zip(*items_grades)\n            score = calculate_score(items, grades)\n            if score is not None:  # implies at least one positive criterion\n                rubric_tag_scores[tag] = score\n        metrics.update(rubric_tag_scores)\n\n        # construct the list of explanations and grades\n        rubric_items_with_grades = []\n        readable_explanation_list = []\n        for rubric_item, grading_response in zip(rubric_items, grading_response_list):\n            explanation = grading_response.get(\"explanation\", \"No explanation provided\")\n            criteria_met = grading_response[\"criteria_met\"]\n            readable_explanation = (\n                f\"[{criteria_met}] {rubric_item}\\n\\tExplanation: {explanation}\"\n            )\n            readable_explanation_list.append(readable_explanation)\n            rubric_items_with_grades.append(\n                {\n                    
**rubric_item.to_dict(),\n                    \"criteria_met\": criteria_met,\n                    \"explanation\": explanation,\n                }\n            )\n\n        readable_explanation_list.sort(\n            key=lambda x: x.startswith(\"[False]\"), reverse=True\n        )\n        readable_explanation_str = \"\\n\\n\".join(readable_explanation_list)\n        readable_explanation_str = f\"\\n\\n{readable_explanation_str}\"\n\n        return metrics, readable_explanation_str, rubric_items_with_grades\n\n    def __call__(self, sampler: SamplerBase) -> EvalResult:\n        def fn(row: dict):\n            prompt_messages = row[\"prompt\"]\n\n            if self.physician_completions_mode is not None:\n                response_text = row[\"completion_to_trial\"]\n                response_usage = None\n                actual_queried_prompt_messages = prompt_messages\n            else:\n                sampler_response = sampler(prompt_messages)\n                response_text = sampler_response.response_text\n                response_dict = sampler_response.response_metadata\n                actual_queried_prompt_messages = (\n                    sampler_response.actual_queried_message_list\n                )\n                response_usage = response_dict.get(\"usage\", None)\n\n            metrics, readable_explanation_str, rubric_items_with_grades = (\n                self.grade_sample(\n                    prompt=actual_queried_prompt_messages,\n                    response_text=response_text,\n                    rubric_items=row[\"rubrics\"],\n                    example_tags=row[\"example_tags\"],\n                )\n            )\n\n            score = metrics[\"overall_score\"]\n\n            # Create HTML for each sample result\n            html = report.jinja_env.from_string(\n                HEALTHBENCH_HTML_JINJA.replace(\n                    \"{{ rubric_grades }}\",\n                    readable_explanation_str.replace(\"\\n\", \"<br>\"),\n         
       )\n            ).render(\n                prompt_messages=actual_queried_prompt_messages,\n                next_message=dict(content=response_text, role=\"assistant\"),\n                score=metrics[\"overall_score\"],\n                extracted_answer=response_text,\n            )\n\n            convo = actual_queried_prompt_messages + [\n                dict(content=response_text, role=\"assistant\")\n            ]\n            return SingleEvalResult(\n                html=html,\n                score=score,\n                convo=convo,\n                metrics=metrics,\n                example_level_metadata={\n                    \"score\": score,\n                    \"usage\": get_usage_dict(response_usage),\n                    \"rubric_items\": rubric_items_with_grades,\n                    \"prompt\": actual_queried_prompt_messages,\n                    \"completion\": [dict(content=response_text, role=\"assistant\")],\n                    \"prompt_id\": row[\"prompt_id\"],\n                    \"completion_id\": hashlib.sha256(\n                        (row[\"prompt_id\"] + response_text).encode(\"utf-8\")\n                    ).hexdigest(),\n                },\n            )\n\n        results = report.map_with_progress(\n            fn,\n            self.examples,\n            num_threads=self.n_threads,\n            pbar=True,\n        )\n        final_metrics = _aggregate_get_clipped_mean(results)\n        return final_metrics\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"HealthBenchEval specific run options, including e.g., running the eval on physician completions rows only.\"\n    )\n    parser.add_argument(\n        \"--run_mode\",\n        type=str,\n        choices=[\"physician_completions\", \"physician_completion_references\"],\n    )\n    parser.add_argument(\"--examples\", type=int, help=\"Number of examples to run\")\n    parser.add_argument(\n        \"--n-threads\",\n        type=int,\n        
default=120,\n        help=\"Number of threads to run\",\n    )\n    args = parser.parse_args()\n\n    if args.run_mode == \"physician_completions\":\n        physician_completions_main(\n            run_reference_completions=False,\n            num_examples=args.examples,\n            n_threads=args.n_threads or 1,\n        )\n    elif args.run_mode == \"physician_completion_references\":\n        physician_completions_main(\n            run_reference_completions=True,\n            num_examples=args.examples,\n            n_threads=args.n_threads or 1,\n        )\n\n    else:\n        raise ValueError(f\"Invalid run mode: {args.run_mode}\")\n\n\ndef physician_completions_main(\n    run_reference_completions: bool = False,\n    num_examples: int | None = None,\n    n_threads: int = 120,\n):\n    now = datetime.now()\n    date_str = now.strftime(\"%Y%m%d_%H%M\")\n\n    grading_sampler = ChatCompletionsSampler(\n        model=\"gpt-4.1-2025-04-14\",\n        system_message=OPENAI_SYSTEM_MESSAGE_API,\n        max_tokens=2048,\n        base_url=\"https://api.openai.com/v1\",\n    )\n    dummy_sampler = SamplerBase()\n\n    merge_metrics = []\n    for pc_mode in PHYSICIAN_COMPLETION_MODES.keys():\n        if (\n            run_reference_completions\n            and not PHYSICIAN_COMPLETION_MODES[pc_mode][\"has_reference\"]\n        ):\n            continue\n\n        # run\n        eval = HealthBenchEval(\n            grader_model=grading_sampler,\n            physician_completions_mode=pc_mode,\n            run_reference_completions=run_reference_completions,\n            num_examples=num_examples,\n            n_threads=n_threads,\n        )\n        result = eval(dummy_sampler)\n\n        # report\n        parsable_mode = PHYSICIAN_COMPLETION_MODES[pc_mode][\"short_name\"]\n        if run_reference_completions:\n            file_stem = f\"healthbench_{parsable_mode}_referencecompletions_{date_str}\"\n        else:\n            file_stem = 
f\"healthbench_{parsable_mode}_humanbaseline_{date_str}\"\n        report_filename = Path(f\"/tmp/{file_stem}.html\")\n        report_filename.write_text(report.make_report(result))\n        print(f\"Report saved to {report_filename}\")\n\n        # metrics\n        assert result.metrics is not None\n        metrics = result.metrics\n        result_filename = Path(f\"/tmp/{file_stem}.json\")\n        result_filename.write_text(json.dumps(metrics))\n        print(f\"Results saved to {result_filename}\")\n\n        full_result_dict = {\n            \"score\": result.score,\n            \"metrics\": result.metrics,\n            \"htmls\": result.htmls,\n            \"convos\": result.convos,\n            \"metadata\": result.metadata,\n        }\n        full_result_filename = Path(f\"/tmp/{file_stem}_allresults.json\")\n        full_result_filename.write_text(json.dumps(full_result_dict, indent=2))\n        print(f\"All results saved to {full_result_filename}\")\n\n        # metrics df\n        merge_metrics.append(\n            {\n                \"eval_name\": \"healthbench\",\n                \"model_name\": f\"{pc_mode} ({PHYSICIAN_COMPLETION_MODES[pc_mode]['description']})\",\n                \"metric\": metrics.get(\"overall_score\", None),\n            }\n        )\n\n    print(\"\\nAll results: \")\n    print(merge_metrics)\n    return merge_metrics\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "gpt_oss/evals/report.py",
    "content": "import os\nfrom collections import defaultdict\nfrom multiprocessing.pool import ThreadPool\nfrom typing import Any, Callable\n\nimport jinja2\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom .types import EvalResult, Message, SingleEvalResult\n\n\nHTML_JINJA = \"\"\"\n<h3>Prompt conversation</h3>\n{% for message in prompt_messages %}\n{{ message_to_html(message) | safe }}\n{% endfor %}\n<h3>Sampled message</h3>\n{{ message_to_html(next_message) | safe }}\n<h3>Results</h3>\n<p>Correct Answer: {{ correct_answer }}</p>\n<p>Extracted Answer: {{ extracted_answer }}</p>\n<p>Score: {{ score }}</p>\n\"\"\"\n\n\ndef _compute_stat(values: list, stat: str):\n    if stat == \"mean\":\n        return np.mean(values)\n    elif stat == \"std\":\n        return np.std(values)\n    elif stat == \"min\":\n        return np.min(values)\n    elif stat == \"max\":\n        return np.max(values)\n    elif stat == \"n_samples\":\n        return len(values)\n    elif stat == \"bootstrap_std\":\n        return np.std(\n            [np.mean(np.random.choice(values, len(values))) for _ in range(1000)]\n        )\n    else:\n        raise ValueError(f\"Unknown {stat =}\")\n\n\ndef aggregate_results(\n    single_eval_results: list[SingleEvalResult],\n    default_stats: tuple[str, ...] 
= (\"mean\", \"std\"),\n    name2stats: dict[str, tuple[str, ...]] | None = None,\n) -> EvalResult:\n    \"\"\"\n    Aggregate results from multiple evaluations into a single EvalResult.\n    \"\"\"\n    name2stats = name2stats or {}\n    name2values = defaultdict(list)\n    htmls = []\n    convos = []\n    metadata = []\n    for single_eval_result in single_eval_results:\n        for name, value in single_eval_result.metrics.items():\n            name2values[name].append(value)\n        if single_eval_result.score is not None:\n            name2values[\"score\"].append(single_eval_result.score)\n        htmls.append(single_eval_result.html)\n        convos.append(single_eval_result.convo)\n        metadata.append(single_eval_result.example_level_metadata)\n    final_metrics = {}\n    for name, values in name2values.items():\n        stats = name2stats.get(name, default_stats)\n        for stat in stats:\n            key = name if stat == \"mean\" else f\"{name}:{stat}\"\n            final_metrics[key] = _compute_stat(values, stat)\n    return EvalResult(\n        score=final_metrics.pop(\"score\", None),\n        metrics=final_metrics,\n        htmls=htmls,\n        convos=convos,\n        metadata={\"example_level_metadata\": metadata},\n    )\n\n\ndef map_with_progress(\n    f: Callable,\n    xs: list[Any],\n    num_threads: int = 128,\n    pbar: bool = True,\n):\n    \"\"\"\n    Apply f to each element of xs using a ThreadPool, showing progress.\n    Results are returned in the same order as xs.\n    \"\"\"\n    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x\n\n    if os.getenv(\"debug\"):\n        return list(map(f, pbar_fn(xs, total=len(xs))))\n    else:\n        # Use at least one worker so that an empty xs does not make ThreadPool raise.\n        with ThreadPool(max(1, min(num_threads, len(xs)))) as pool:\n            # imap (unlike imap_unordered) preserves input order; callers such as\n            # HealthBenchEval.grade_sample zip the results back to their inputs.\n            return list(pbar_fn(pool.imap(f, xs), total=len(xs)))\n\n\njinja_env = jinja2.Environment(\n    loader=jinja2.BaseLoader(),\n    undefined=jinja2.StrictUndefined,\n    autoescape=jinja2.select_autoescape([\"html\", \"xml\"]),\n)\n_message_template = \"\"\"\n<div class=\"message {{ role }}\">\n    <div class=\"role\">\n    {{ role }}\n    {% if variant %}<span class=\"variant\">({{ variant }})</span>{% endif %}\n    </div>\n    <div class=\"content\">\n    <pre>{{ content }}</pre>\n    </div>\n</div>\n\"\"\"\n\n\ndef message_to_html(message: Message) -> str:\n    \"\"\"\n    Generate HTML snippet (inside a <div>) for a message.\n    \"\"\"\n    return jinja_env.from_string(_message_template).render(\n        role=message[\"role\"],\n        content=message[\"content\"],\n        variant=message.get(\"variant\", None),\n    )\n\n\njinja_env.globals[\"message_to_html\"] = message_to_html\n\n\n_report_template = \"\"\"<!DOCTYPE html>\n<html>\n    <head>\n        <meta charset=\"utf-8\">\n        <style>\n            .message {\n                padding: 8px 16px;\n                margin-bottom: 8px;\n                border-radius: 4px;\n            }\n            .message.user {\n                background-color: #B2DFDB;\n                color: #00695C;\n            }\n            .message.assistant {\n                background-color: #B39DDB;\n                color: #4527A0;\n            }\n            .message.system {\n                background-color: #EEEEEE;\n                color: #212121;\n            }\n            .role {\n                font-weight: bold;\n                margin-bottom: 4px;\n            }\n            .variant {\n                color: #795548;\n            }\n            table, th, td {\n                border: 1px solid black;\n            }\n            pre {\n                white-space: pre-wrap;\n            }\n        </style>\n    </head>\n    <body>\n    {% if metrics %}\n    <h1>Metrics</h1>\n    <table>\n    <tr>\n        <th>Metric</th>\n        <th>Value</th>\n    </tr>\n    <tr>\n        <td><b>Score</b></td>\n        <td>{{ score | float | round(3) }}</td>\n    </tr>\n    {% for name, value in metrics.items() %}\n    <tr>\n        <td>{{ name }}</td>\n        <td>{{ value }}</td>\n 
   </tr>\n    {% endfor %}\n    </table>\n    {% endif %}\n    <h1>Examples</h1>\n    {% for html in htmls %}\n    {{ html | safe }}\n    <hr>\n    {% endfor %}\n    </body>\n</html>\n\"\"\"\n\n\ndef make_report(eval_result: EvalResult) -> str:\n    \"\"\"\n    Create a standalone HTML report from an EvalResult.\n    \"\"\"\n    return jinja_env.from_string(_report_template).render(\n        score=eval_result.score,\n        metrics=eval_result.metrics,\n        htmls=eval_result.htmls,\n    )\n"
  },
  {
    "path": "gpt_oss/evals/responses_sampler.py",
    "content": "import time\nfrom typing import Any\n\nimport openai\nfrom openai import OpenAI\n\nfrom .types import MessageList, SamplerBase, SamplerResponse\n\n\nclass ResponsesSampler(SamplerBase):\n    \"\"\"\n    Sample from OpenAI's responses API\n    \"\"\"\n\n    def __init__(\n        self,\n        model: str,\n        developer_message: str | None = None,\n        temperature: float = 1.0,\n        max_tokens: int = 131_072,\n        reasoning_model: bool = False,\n        reasoning_effort: str | None = None,\n        base_url: str = \"http://localhost:8000/v1\",\n    ):\n        self.client = OpenAI(base_url=base_url, timeout=24*60*60)\n        self.model = model\n        self.developer_message = developer_message\n        self.temperature = temperature\n        self.max_tokens = max_tokens\n        self.image_format = \"url\"\n        self.reasoning_model = reasoning_model\n        self.reasoning_effort = reasoning_effort\n\n    def _pack_message(self, role: str, content: Any) -> dict[str, Any]:\n        return {\"role\": role, \"content\": content}\n\n    def __call__(self, message_list: MessageList) -> SamplerResponse:\n        if self.developer_message:\n            message_list = [\n                self._pack_message(\"developer\", self.developer_message)\n            ] + message_list\n        trial = 0\n        while True:\n            try:\n                request_kwargs = {\n                    \"model\": self.model,\n                    \"input\": message_list,\n                    \"temperature\": self.temperature,\n                    \"max_output_tokens\": self.max_tokens,\n                }\n                if self.reasoning_model:\n                    request_kwargs[\"reasoning\"] = (\n                        {\"effort\": self.reasoning_effort} if self.reasoning_effort else None\n                    )\n                response = self.client.responses.create(**request_kwargs)\n\n                for output in response.output:\n              
      if hasattr(output, \"text\"):\n                        message_list.append(self._pack_message(getattr(output, \"role\", \"assistant\"), output.text))\n\n                return SamplerResponse(\n                    response_text=response.output_text,\n                    response_metadata={\"usage\": response.usage},\n                    actual_queried_message_list=message_list,\n                )\n            except openai.BadRequestError as e:\n                print(\"Bad Request Error\", e)\n                return SamplerResponse(\n                    response_text=\"\",\n                    response_metadata={\"usage\": None},\n                    actual_queried_message_list=message_list,\n                )\n            except Exception as e:\n                # Any other error (e.g. rate limiting) is retried indefinitely with exponential backoff\n                exception_backoff = 2**trial\n                print(\n                    f\"Exception, so wait and retry {trial} after {exception_backoff} sec\",\n                    e,\n                )\n                time.sleep(exception_backoff)\n                trial += 1\n"
  },
  {
    "path": "gpt_oss/evals/types.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Any, Literal, overload\n\nMessage = dict[str, Any]  # keys role, content\nMessageList = list[Message]\n\n\n@dataclass\nclass SamplerResponse:\n    \"\"\"\n    Response from a sampler.\n    \"\"\"\n    response_text: str\n    actual_queried_message_list: MessageList\n    response_metadata: dict[str, Any]\n\n\nclass SamplerBase:\n    \"\"\"\n    Base class for defining a sampling model, which can be evaluated,\n    or used as part of the grading process.\n    \"\"\"\n\n    def __call__(\n        self,\n        message_list: MessageList,\n    ) -> SamplerResponse:\n        raise NotImplementedError\n\n\n@dataclass\nclass EvalResult:\n    \"\"\"\n    Result of running an evaluation (usually consisting of many samples)\n    \"\"\"\n\n    score: float | None  # top-line metric\n    metrics: dict[str, float] | None  # other metrics\n    htmls: list[str]  # strings of valid HTML\n    convos: list[MessageList]  # sampled conversations\n    metadata: dict[str, Any] | None  # Extra data such as rubric scores or sollen\n\n\n@dataclass\nclass SingleEvalResult:\n    \"\"\"\n    Result of evaluating a single sample\n    \"\"\"\n\n    score: float | None\n    metrics: dict[str, float] = field(default_factory=dict)\n    html: str | None = None\n    convo: MessageList | None = None  # sampled conversation\n    example_level_metadata: dict[str, Any] | None = (\n        None  # Extra data such as rubric scores or sollen\n    )\n\n\nclass Eval:\n    \"\"\"\n    Base class for defining an evaluation.\n    \"\"\"\n\n    def __call__(self, sampler: SamplerBase) -> EvalResult:\n        raise NotImplementedError\n"
  },
  {
    "path": "gpt_oss/generate.py",
    "content": "# Model parallel inference\n# Note: This script is for demonstration purposes only. It is not designed for production use.\n#       See gpt_oss.chat for a more complete example with the Harmony parser.\n# torchrun --nproc-per-node=4 -m gpt_oss.generate -p \"why did the chicken cross the road?\" model/\n\nimport argparse\n\nfrom gpt_oss.tokenizer import get_tokenizer\n\n\ndef main(args):\n    match args.backend:\n        case \"torch\":\n            from gpt_oss.torch.utils import init_distributed\n            from gpt_oss.torch.model import TokenGenerator as TorchGenerator\n            device = init_distributed()\n            generator = TorchGenerator(args.checkpoint, device=device)\n        case \"triton\":\n            from gpt_oss.torch.utils import init_distributed\n            from gpt_oss.triton.model import TokenGenerator as TritonGenerator\n            device = init_distributed()\n            generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)\n        case \"vllm\":\n            from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator\n            generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)\n        case _:\n            raise ValueError(f\"Invalid backend: {args.backend}\")\n\n    tokenizer = get_tokenizer()\n    tokens = tokenizer.encode(args.prompt)\n    max_tokens = None if args.limit == 0 else args.limit\n    for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True):\n        tokens.append(token)\n        token_text = tokenizer.decode([token])\n        print(\n            f\"Generated token: {repr(token_text)}, logprob: {logprob}\"\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Text generation example\")\n    parser.add_argument(\n        \"checkpoint\",\n        metavar=\"FILE\",\n        
type=str,\n        help=\"Path to the SafeTensors checkpoint\",\n    )\n    parser.add_argument(\n        \"-p\",\n        \"--prompt\",\n        metavar=\"PROMPT\",\n        type=str,\n        default=\"How are you?\",\n        help=\"LLM prompt\",\n    )\n    parser.add_argument(\n        \"-t\",\n        \"--temperature\",\n        metavar=\"TEMP\",\n        type=float,\n        default=0.0,\n        help=\"Sampling temperature\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--limit\",\n        metavar=\"LIMIT\",\n        type=int,\n        default=0,\n        help=\"Limit on the number of tokens (0 to disable)\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--backend\",\n        metavar=\"BACKEND\",\n        type=str,\n        default=\"torch\",\n        choices=[\"triton\", \"torch\", \"vllm\"],\n        help=\"Inference backend\",\n    )\n    parser.add_argument(\n        \"--tensor-parallel-size\",\n        type=int,\n        default=2,\n        help=\"Tensor parallel size for vLLM backend\",\n    )\n    parser.add_argument(\n        \"--context-length\",\n        type=int,\n        default=4096,\n        help=\"Context length for Triton backend\",\n    )\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "gpt_oss/metal/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.24)\nproject(GPTOSS\n    VERSION 1.0\n    DESCRIPTION \"Local GPT-OSS inference\"\n    LANGUAGES C CXX OBJC)\n\nset(CMAKE_C_STANDARD 11)\nset(CMAKE_CXX_STANDARD 20)\nset(CMAKE_OBJC_STANDARD 11)\nset(CMAKE_OBJC_STANDARD_REQUIRED ON)\n\nfind_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)\nfind_library(METAL_FRAMEWORK      Metal      REQUIRED)\nfind_library(IOKIT_FRAMEWORK      IOKit      REQUIRED)\n\nset(METAL_SOURCES\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal\n    ${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal\n)\nset(METAL_LIB default.metallib)\n\ninclude_directories(BEFORE include source/include)\n\nadd_custom_command(\n    OUTPUT  ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}\n    COMMAND ${CMAKE_COMMAND} -E make_directory \"${CMAKE_CURRENT_BINARY_DIR}/source/\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/convert.air\"\n    COMMAND xcrun -sdk macosx metal -g 
\"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/random.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/rope.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/sample.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal\" -o 
\"${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air\"\n    COMMAND xcrun -sdk macosx metal -g \"-I${CMAKE_CURRENT_SOURCE_DIR}/source/include\" -c \"${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal\" -o \"${CMAKE_CURRENT_BINARY_DIR}/source/topk.air\"\n    COMMAND xcrun -sdk macosx metallib \"${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/convert.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/random.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/rope.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/sample.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air\" \"${CMAKE_CURRENT_BINARY_DIR}/source/topk.air\" -o \"${METAL_LIB}\"\n    DEPENDS ${METAL_SOURCES}\n    COMMENT \"Compiling Metal compute library\"\n)\n\nadd_custom_target(build_metallib ALL\n    DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})\n\nadd_library(log OBJECT source/log.c)\n\nadd_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)\ntarget_link_libraries(metal-kernels PRIVATE log)\n\nadd_dependencies(metal-kernels build_metallib)\nadd_custom_command(TARGET metal-kernels POST_BUILD\n    COMMAND ${CMAKE_COMMAND} -E copy\n            ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}\n            $<TARGET_FILE_DIR:metal-kernels>)\n\ntarget_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})\n\nadd_library(gptoss STATIC source/model.c source/tokenizer.c 
source/context.c)\ntarget_link_libraries(gptoss PRIVATE log metal-kernels)\n\nadd_executable(generate source/generate.c)\ntarget_link_libraries(generate gptoss)\n\n# --- [ Tests\ninclude(FetchContent)\nFetchContent_Declare(\n    googletest\n    URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip\n    DOWNLOAD_EXTRACT_TIMESTAMP OFF\n)\n# For Windows: Prevent overriding the parent project's compiler/linker settings\nset(gtest_force_shared_crt ON CACHE BOOL \"\" FORCE)\nset(INSTALL_GTEST OFF CACHE BOOL \"\" FORCE)\nFetchContent_MakeAvailable(googletest)\n\nenable_testing()\n\nadd_executable(u32-random-test test/u32-random.cc)\ntarget_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(u32-random-test PRIVATE source/include)\nadd_test(NAME u32-random-test COMMAND u32-random-test)\n\nadd_executable(f32-random-test test/f32-random.cc)\ntarget_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(f32-random-test PRIVATE source/include)\nadd_test(NAME f32-random-test COMMAND f32-random-test)\n\nadd_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)\ntarget_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(mf4-f32-convert-test PRIVATE source/include)\nadd_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)\n\nadd_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)\ntarget_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(bf16-f32-embeddings-test PRIVATE source/include)\nadd_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)\n\nadd_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)\ntarget_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)\nadd_test(NAME f32-bf16w-rmsnorm-test COMMAND 
f32-bf16w-rmsnorm-test)\n\nadd_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)\ntarget_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(f32-bf16w-matmul-test PRIVATE source/include)\nadd_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)\n\nadd_executable(f32-rope-test test/f32-rope.cc)\ntarget_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)\ntarget_include_directories(f32-rope-test PRIVATE source/include)\nadd_test(NAME f32-rope-test COMMAND f32-rope-test)\n\n# --- [ Benchmarks\ninclude(FetchContent)\nset(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL \"Disable self-tests in Google Benchmark\" FORCE)\nset(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL \"Disable installation of Google Benchmark\" FORCE)\nFetchContent_Declare(\n    benchmark\n    URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip\n    DOWNLOAD_EXTRACT_TIMESTAMP OFF\n)\nFetchContent_MakeAvailable(benchmark)\n\nadd_executable(f32-random-bench benchmark/f32-random.cc)\ntarget_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)\ntarget_include_directories(f32-random-bench PRIVATE source/include)\n\nadd_executable(u32-random-bench benchmark/u32-random.cc)\ntarget_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)\ntarget_include_directories(u32-random-bench PRIVATE source/include)\n\nadd_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)\ntarget_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)\ntarget_include_directories(mf4-f32-convert-bench PRIVATE source/include)\n\nadd_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)\ntarget_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)\ntarget_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)\n\nadd_executable(end-to-end-bench benchmark/end-to-end.cc)\ntarget_link_libraries(end-to-end-bench PRIVATE 
benchmark::benchmark gptoss)\ntarget_include_directories(end-to-end-bench PRIVATE source/include)\n\nadd_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)\ntarget_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)\ntarget_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)\n\n# --- [ Python extension ] -----------------------------------------------\nfind_package(pybind11 CONFIG REQUIRED)          # provides pybind11_add_module\n\npybind11_add_module(_metal\n    python/module.c\n    python/context.c\n    python/model.c\n    python/tokenizer.c\n)\nset_target_properties(_metal PROPERTIES PREFIX \"\")\n\ntarget_link_libraries(_metal PRIVATE gptoss)\nadd_dependencies(_metal build_metallib)\ntarget_link_options(_metal PRIVATE\n    LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}\n)\nadd_custom_command(TARGET _metal POST_BUILD\n    COMMAND ${CMAKE_COMMAND} -E copy\n            ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}\n            $<TARGET_FILE_DIR:_metal>)\n\n# 1️⃣  install the extension module into the Python package\ninstall(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)\n\n# 2️⃣  make sure the Metal shader archive travels with it\ninstall(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}\n        DESTINATION gpt_oss/metal)\n# ------------------------------------------------------------------------"
  },
  {
    "path": "gpt_oss/metal/__init__.py",
    "content": "from importlib import import_module as _im\n\n# Load the compiled extension (gpt_oss.metal._metal)\n_ext = _im(f\"{__name__}._metal\")\nglobals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith(\"_\")})\ndel _im, _ext\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/end-to-end-threadgroup.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/model.h>\n\n#include <array>\n#include <cstdint>\n#include <cstddef>\n#include <cstdlib>\n#include <cstring>\n#include <format>\n#include <limits>\n#include <memory>\n#include <string>\n#include <type_traits>\n\n#include <benchmark/benchmark.h>\n\n\nconstexpr std::uint32_t kNumGeneratedTokens = 100;\n\n\nstatic void attn_qkv_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->attn_qkv_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if 
(status != gptoss_status_success) {\n        state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void AttnQKVThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto attn_qkv_threadgroup_size = 32; attn_qkv_threadgroup_size <= 1024; attn_qkv_threadgroup_size += 32) {\n        const auto num_simdgroups = attn_qkv_threadgroup_size / 32;\n        if (5120 % num_simdgroups != 0) {\n            // Skip incompatible threadgroup sizes\n            continue;\n        }\n        b->Args({attn_qkv_threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);\n\nstatic void attn_out_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->attn_out_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        
state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void AttnOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto attn_out_threadgroup_size = 32; attn_out_threadgroup_size <= 1024; attn_out_threadgroup_size += 32) {\n        const auto num_simdgroups = attn_out_threadgroup_size / 32;\n        if (2880 % num_simdgroups != 0) {\n            // Skip incompatible threadgroup sizes\n            continue;\n        }\n        b->Args({attn_out_threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments);\n\nstatic void mlp_gate_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->mlp_gate_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        
state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void MlpGateThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto mlp_gate_threadgroup_size = 32; mlp_gate_threadgroup_size <= 1024; mlp_gate_threadgroup_size += 32) {\n        const auto num_simdgroups = mlp_gate_threadgroup_size / 32;\n        if (128 % num_simdgroups != 0) {\n            // Skip incompatible threadgroup sizes\n            continue;\n        }\n        b->Args({mlp_gate_threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments);\n\nstatic void mlp_swiglu_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->mlp_swiglu_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        
state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void MlpSwigluThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) {\n        const auto num_simdgroups = threadgroup_size / 32;\n        if (5760 % num_simdgroups != 0) {\n            // Skip incompatible threadgroup sizes\n            continue;\n        }\n        b->Args({threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments);\n\nstatic void mlp_out_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->mlp_out_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        
state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void MlpOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) {\n        const auto num_simdgroups = threadgroup_size / 32;\n        if (5760 % num_simdgroups != 0) {\n            // Skip incompatible threadgroup sizes\n            continue;\n        }\n        b->Args({threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    
->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments);\n\nstatic void mlp_acc_tgsize(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->mlp_acc_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed 
to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void MlpAccThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) {\n        b->Args({threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments);\n\nstatic void unembedding_tgsize(benchmark::State& state, const char* 
env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n    model->unembedding_threadgroup_size = static_cast<std::size_t>(state.range(0));\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    const std::size_t num_kvcache_tokens = context->num_kv_tokens;\n\n    std::uint64_t rng_seed = 0;\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        
context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void UnembeddingThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {\n    b->ArgNames({\"tgsize\"});\n    for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) {\n        b->Args({threadgroup_size});\n    }\n}\n\nBENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_20b, \"GPT_OSS_20B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments);\nBENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_120b, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/end-to-end.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/model.h>\n\n#include <array>\n#include <cstddef>\n#include <cstdint>\n#include <format>\n#include <fstream>\n#include <limits>\n#include <memory>\n#include <string>\n#include <type_traits>\n\n#include <benchmark/benchmark.h>\n\nconstexpr std::uint32_t kNumGeneratedTokens = 100;\n\nstatic void end2end_decode(benchmark::State& state, const char* env_var_name) {\n    const char* model_path = getenv(env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\", env_var_name));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(), /*context_length=*/0, /*max_batch_tokens=*/0, &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);\n\n    const char* prompt = \"why did the chicken cross the road?\";\n    std::size_t num_prompt_tokens = 0;\n    status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\"failed to tokenize prompt \\\"{}\\\"\", prompt));\n        return;\n    }\n\n    // Prefill\n    status = gptoss_context_process(context.get());\n    if (status != gptoss_status_success) {\n        
state.SkipWithError(\"failed to prefill Context object\");\n        return;\n    }\n    std::uint64_t rng_seed = 0;\n\n    for (auto _ : state) {\n        const std::uint64_t current_rng_seed = rng_seed++;\n        context->num_kv_tokens = num_prompt_tokens;\n        context->num_tokens = num_prompt_tokens;\n\n        std::array<std::uint32_t, kNumGeneratedTokens> tokens;\n        std::size_t num_generated_tokens = 0;\n        do {\n            std::size_t num_current_generated_tokens = 0;\n            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,\n                                           /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);\n            if (status != gptoss_status_success) {\n                state.SkipWithError(\"failed to sample from the Context object\");\n                return;\n            }\n            num_generated_tokens += num_current_generated_tokens;\n        } while (num_generated_tokens < kNumGeneratedTokens);\n    }\n\n    state.counters[\"generations\"] =\n        benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);\n    state.counters[\"tokens\"] =\n        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);\n}\n\nstatic void end2end_prefill(benchmark::State& state,\n                            const char* model_path_env_var_name,\n                            const char* prompt_env_var_name,\n                            size_t context_length = 0) {\n    const char* model_path = getenv(model_path_env_var_name);\n    if (model_path == NULL) {\n        state.SkipWithError(std::format(\"environment variable {} is not set\",\n                                        model_path_env_var_name));\n        return;\n    }\n\n    const char* prompt_file_path = getenv(prompt_env_var_name);\n    if (prompt_file_path == NULL) {\n        state.SkipWithError(std::format(\"environment 
variable {} is not set\",\n                                        prompt_env_var_name));\n        return;\n    }\n\n    // Read prompt contents from file into a std::string\n    std::ifstream prompt_file(prompt_file_path,\n                              std::ios::in | std::ios::binary);\n    if (!prompt_file) {\n        state.SkipWithError(\n            std::format(\"failed to open prompt file {}\", prompt_file_path));\n        return;\n    }\n    std::string prompt_str;\n    prompt_file.seekg(0, std::ios::end);\n    std::streampos file_size = prompt_file.tellg();\n    if (file_size < 0) {\n        state.SkipWithError(std::format(\"failed to read prompt file size {}\",\n                                        prompt_file_path));\n        return;\n    }\n    prompt_str.resize(static_cast<std::size_t>(file_size));\n    prompt_file.seekg(0, std::ios::beg);\n    if (file_size > 0) {\n        prompt_file.read(prompt_str.data(), file_size);\n    }\n    if (!prompt_file) {\n        state.SkipWithError(\n            std::format(\"failed to read prompt file {}\", prompt_file_path));\n        return;\n    }\n\n    gptoss_model_t model_ptr = nullptr;\n    gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\n            std::format(\"failed to load model from file {}\", model_path));\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_model_t>,\n                    decltype(&gptoss_model_release)>\n        model(model_ptr, gptoss_model_release);\n\n    gptoss_tokenizer_t tokenizer_ptr = nullptr;\n    status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to retrieve Tokenizer\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_tokenizer_t>,\n                    decltype(&gptoss_tokenizer_release)>\n        tokenizer(tokenizer_ptr, 
gptoss_tokenizer_release);\n\n    gptoss_context_t context_ptr = nullptr;\n    status = gptoss_context_create(model.get(),\n                                   /*context_length=*/0,\n                                   /*max_batch_tokens=*/1024,\n                                   &context_ptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to create Context object\");\n        return;\n    }\n    std::unique_ptr<std::remove_pointer_t<gptoss_context_t>,\n                    decltype(&gptoss_context_release)>\n        context(context_ptr, gptoss_context_release);\n\n    const char* prompt = prompt_str.c_str();\n    status = gptoss_context_append_chars(context.get(), prompt,\n                                         prompt_str.size(), nullptr);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(std::format(\n            \"failed to tokenize prompt from file {}\", prompt_file_path));\n        return;\n    }\n\n    size_t num_tokens;\n    status = gptoss_context_get_num_tokens(context.get(), &num_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to get number of tokens\");\n        return;\n    }\n    if (context_length != 0) {\n        if (context_length > num_tokens) {\n            state.SkipWithError(std::format(\n                \"prompt has {} tokens, fewer than the requested context length {}\",\n                num_tokens, context_length));\n            return;\n        }\n        context->num_tokens = context_length;\n    }\n    status = gptoss_context_get_num_tokens(context.get(), &num_tokens);\n    if (status != gptoss_status_success) {\n        state.SkipWithError(\"failed to get number of tokens\");\n        return;\n    }\n    // Prefill\n    for (auto _ : state) {\n        status = gptoss_context_process(context.get());\n        if (status != gptoss_status_success) {\n            state.SkipWithError(\"failed to prefill Context object\");\n            return;\n        }\n        context->num_kv_tokens = 0;\n    }\n\n    state.counters[\"tokens\"] = num_tokens;\n    state.counters[\"tokens/s\"] = benchmark::Counter(\n        state.iterations() * num_tokens, 
benchmark::Counter::kIsRate);\n}\n\n// Decode end-to-end benchmark\nBENCHMARK_CAPTURE(end2end_decode, gpt_oss_20b_decode, \"GPT_OSS_20B_PATH\")\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\nBENCHMARK_CAPTURE(end2end_decode, gpt_oss_120b_decode, \"GPT_OSS_120B_PATH\")\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\n\n// Prefill end-to-end benchmark\nBENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_1024,\n                  \"GPT_OSS_120B_PATH\", \"GPT_OSS_PROMPT_FILE_PATH\", 1024)\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\nBENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_1024, \"GPT_OSS_20B_PATH\",\n                  \"GPT_OSS_PROMPT_FILE_PATH\", 1024)\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\n\nBENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_3072,\n                  \"GPT_OSS_120B_PATH\", \"GPT_OSS_PROMPT_FILE_PATH\", 3072)\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\nBENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_3072, \"GPT_OSS_20B_PATH\",\n                  \"GPT_OSS_PROMPT_FILE_PATH\", 3072)\n    ->UseRealTime()\n    ->Unit(benchmark::kMillisecond);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/datatype.h>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n#include <cstring>\n\n#include <benchmark/benchmark.h>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\nconstexpr float kEpsilon = 1.0e-5f;\nconstexpr uint64_t kSeed = UINT64_C(1019827666124465388);\n\nstatic void f32_bf16w_rnsnorm(benchmark::State& state) {\n    const size_t num_tokens = 1;\n    const size_t num_channels = state.range(0);\n\n    Device device;\n    CommandQueue command_queue{device};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Function bf16_fill_random_fn{library, \"gptoss_bf16_fill_random\"};\n    Function f32_bf16w_rmsnorm_fn{library, \"gptoss_f32_bf16w_rmsnorm\"};\n    Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)};\n    Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};\n    Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)};\n    Buffer control_buffer{device, sizeof(gptoss_control)};\n    std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));\n\n    {\n        CommandBuffer command_buffer{command_queue};\n\n        size_t offset = 0;\n        Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n                command_buffer.handle(),\n                f32_fill_random_fn.handle(),\n                /*threadgroup_size=*/0,\n                /*max_threadgroups=*/10,\n                /*output_buffer=*/input_buffer.handle(),\n                /*output_offset=*/0,\n                num_channels, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0),\n            \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n        offset += num_channels;\n\n        Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(\n                command_buffer.handle(),\n                bf16_fill_random_fn.handle(),\n                /*threadgroup_size=*/0,\n            
    /*max_threadgroups=*/10,\n                /*output_buffer=*/weight_buffer.handle(),\n                /*output_offset=*/0,\n                num_channels, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0),\n            \"gptoss_metal_command_buffer_encode_launch_bf16_fill_random\");\n        offset += num_channels;\n\n        command_buffer.commit();\n        command_buffer.wait_completion();\n    }\n\n    for (auto _ : state) {\n        CommandBuffer command_buffer{command_queue};\n\n        Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n                command_buffer.handle(),\n                f32_bf16w_rmsnorm_fn.handle(),\n                input_buffer.handle(),\n                /*input_offset=*/0,\n                weight_buffer.handle(),\n                /*weight_offset=*/0,\n                output_buffer.handle(),\n                /*output_offset=*/0,\n                control_buffer.handle(),\n                /*control_offset=*/0,\n                num_tokens,\n                num_channels,\n                kEpsilon),\n            \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm\");\n\n        command_buffer.commit();\n        const double elapsed_seconds = command_buffer.wait_completion();\n        state.SetIterationTime(elapsed_seconds);\n    }\n\n    const size_t num_elements = num_tokens * num_channels;\n    state.counters[\"elements\"] =\n        benchmark::Counter(state.iterations() * num_elements,\n                           benchmark::Counter::kIsRate);\n\n    const int64_t bytes_per_iteration = input_buffer.size() + weight_buffer.size() + output_buffer.size();\n    state.counters[\"bytes\"] =\n        benchmark::Counter(state.iterations() * bytes_per_iteration,\n                           benchmark::Counter::kIsRate);\n}\n\nBENCHMARK(f32_bf16w_rnsnorm)->Arg(2880)->UseManualTime()->Unit(benchmark::kMicrosecond);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/f32-random.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n#include <benchmark/benchmark.h>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\nstatic void f32_fill_random(benchmark::State& state) {\n    const size_t numel = state.range(0);\n\n    Device device;\n    CommandQueue command_queue{device};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, numel * sizeof(float)};\n\n    constexpr uint64_t seed = UINT64_C(1019827666124465388);\n    constexpr uint64_t offset = UINT64_C(12345678901234567890);\n    const float min = -1.0f;\n    const float max = 7.0f;\n    for (auto _ : state) {\n        CommandBuffer command_buffer{command_queue};\n\n        Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n                command_buffer.handle(),\n                f32_fill_random_fn.handle(),\n                /*threadgroup_size=*/0,\n                /*max_threadgroups=*/120,\n                /*output_buffer=*/buffer.handle(),\n                /*output_offset=*/0,\n                numel, seed, offset, min, max),\n            \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n        command_buffer.commit();\n        const double elapsed_seconds = command_buffer.wait_completion();\n        state.SetIterationTime(elapsed_seconds);\n    }\n    \n    const int64_t elements_per_iteration = numel;\n    state.counters[\"elements\"] =\n        benchmark::Counter(state.iterations() * elements_per_iteration,\n                           benchmark::Counter::kIsRate);\n\n    const int64_t bytes_per_iteration = numel * sizeof(float);\n    state.counters[\"bytes\"] =\n        benchmark::Counter(state.iterations() * bytes_per_iteration,\n                           benchmark::Counter::kIsRate);\n}\n\nconstexpr int64_t giga = INT64_C(1073741824);\nBENCHMARK(f32_fill_random)->Arg(2 * 
giga)->UseManualTime()->Unit(benchmark::kMicrosecond);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/mf4-f32-convert.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/datatype.h>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n#include <cstring>\n\n#include <benchmark/benchmark.h>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\nstatic void mf4_f32_convert(benchmark::State& state) {\n    const size_t num_blocks = state.range(0);\n    const size_t num_elements = num_blocks * 32;\n    const size_t num_bytes = num_elements / 2;\n\n    Device device;\n    CommandQueue command_queue{device};\n    Library library{device};\n    Function mf4_f32_convert_fn{library, \"gptoss_mf4_f32_convert\"};\n    Buffer block_buffer{device, num_bytes};\n    Buffer scale_buffer{device, num_blocks * sizeof(gptoss_float8ue8m0)};\n    Buffer output_buffer{device, num_elements * sizeof(float)};\n\n    std::memset(block_buffer.ptr(), 0x91, num_bytes);  // force subnormals\n    std::memset(scale_buffer.ptr(), 128, num_blocks * sizeof(uint8_t));  // scale = 2.0\n\n    for (auto _ : state) {\n        CommandBuffer command_buffer{command_queue};\n\n        Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(\n                command_buffer.handle(),\n                mf4_f32_convert_fn.handle(),\n                /*threadgroup_size=*/0,\n                /*max_threadgroups=*/120,\n                block_buffer.handle(),\n                scale_buffer.handle(),\n                output_buffer.handle(),\n                num_elements),\n            \"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert\");\n\n        command_buffer.commit();\n        const double elapsed_seconds = command_buffer.wait_completion();\n        state.SetIterationTime(elapsed_seconds);\n    }\n\n    state.counters[\"blocks\"] =\n        benchmark::Counter(state.iterations() * num_blocks,\n                           benchmark::Counter::kIsRate);\n\n    state.counters[\"elements\"] =\n        benchmark::Counter(state.iterations() * num_elements,\n                           
benchmark::Counter::kIsRate);\n\n    const int64_t bytes_per_iteration = num_bytes + num_blocks + num_elements * sizeof(float);\n    state.counters[\"bytes\"] =\n        benchmark::Counter(state.iterations() * bytes_per_iteration,\n                           benchmark::Counter::kIsRate);\n}\n\nconstexpr int64_t mega = INT64_C(1048576);\nBENCHMARK(mf4_f32_convert)->Arg(256 * mega)->UseManualTime()->Unit(benchmark::kMicrosecond);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/benchmark/u32-random.cc",
    "content": "#include <gpt-oss.h>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n#include <benchmark/benchmark.h>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\nstatic void u32_fill_random(benchmark::State& state) {\n    const size_t numel = state.range(0);\n\n    Device device;\n    CommandQueue command_queue{device};\n    Library library{device};\n    Function u32_fill_random_fn{library, \"gptoss_u32_fill_random\"};\n    Buffer buffer{device, numel * sizeof(float)};\n\n    constexpr uint64_t seed = UINT64_C(1019827666124465388);\n    constexpr uint64_t offset = UINT64_C(12345678901234567890);\n    for (auto _ : state) {\n        CommandBuffer command_buffer{command_queue};\n\n        Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(\n                command_buffer.handle(),\n                u32_fill_random_fn.handle(),\n                /*threadgroup_size=*/0,\n                /*max_threadgroups=*/120,\n                /*output_buffer=*/buffer.handle(),\n                /*output_offset=*/0,\n                numel, seed, offset),\n            \"gptoss_metal_command_buffer_encode_launch_u32_fill_random\");\n\n        command_buffer.commit();\n        const double elapsed_seconds = command_buffer.wait_completion();\n        state.SetIterationTime(elapsed_seconds);\n    }\n    \n    const int64_t elements_per_iteration = numel;\n    state.counters[\"elements\"] =\n        benchmark::Counter(state.iterations() * elements_per_iteration,\n                           benchmark::Counter::kIsRate);\n\n    const int64_t bytes_per_iteration = numel * sizeof(float);\n    state.counters[\"bytes\"] =\n        benchmark::Counter(state.iterations() * bytes_per_iteration,\n                           benchmark::Counter::kIsRate);\n}\n\nconstexpr int64_t giga = INT64_C(1073741824);\nBENCHMARK(u32_fill_random)->Arg(2 * giga)->UseManualTime()->Unit(benchmark::kMicrosecond);\n\nBENCHMARK_MAIN();\n"
  },
  {
    "path": "gpt_oss/metal/examples/chat.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nimport sys\n\nfrom datetime import date\nfrom gpt_oss.metal import Context, Model\n\n\nDEFAULT_PROMPT = f\"\"\"You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: {date.today().isoformat()}\n\nreasoning effort high\n\n# Valid channels: analysis, final. Channel must be included for every message.\"\"\"\n\n\nparser = argparse.ArgumentParser(description=\"Chat with gpt-oss\", formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument(\"model\", metavar=\"PATH\", type=str, help=\"Path to gpt-oss model in Metal inference format\")\nparser.add_argument(\"--prompt\", type=str, default=DEFAULT_PROMPT, help=\"System prompt\")\nparser.add_argument(\n    \"--context-length\", type=int, default=0, help=\"The maximum context length\"\n)\nparser.add_argument(\n    \"--temperature\", type=float, default=1.0, help=\"Sampling temperature\"\n)\nparser.add_argument(\n    \"--seed\", type=int, default=0, help=\"Sampling seed\"\n)\n\n\nGREY = \"\\33[90m\"\nBOLD = \"\\33[1m\"\nRESET = \"\\33[0m\"\n\n\ndef main(args):\n    options = parser.parse_args(args)\n    model = Model(options.model)\n    tokenizer = model.tokenizer\n    start_token = tokenizer.encode_special_token(\"<|start|>\")\n    message_token = tokenizer.encode_special_token(\"<|message|>\")\n    end_token = tokenizer.encode_special_token(\"<|end|>\")\n    return_token = tokenizer.encode_special_token(\"<|return|>\")\n    channel_token = tokenizer.encode_special_token(\"<|channel|>\")\n\n    context = Context(model, context_length=options.context_length)\n    context.append(start_token)\n    context.append(\"system\")\n    context.append(message_token)\n    context.append(options.prompt)\n    context.append(end_token)\n\n    while True:\n        context.append(start_token)\n        context.append(\"user\")\n        context.append(message_token)\n        message = input(f\"{BOLD}User:{RESET} 
\").rstrip()\n        context.append(message)\n        context.append(end_token)\n        print(f\"{BOLD}Assistant:{RESET} {GREY}\", end=\"\", flush=True)\n        context.append(start_token)\n        context.append(\"assistant\")\n        context.append(channel_token)\n\n        inside_start_block = True\n        inside_channel_block = True\n        role = \"assistant\"\n        channel = \"\"\n        while True:\n            token = context.sample(\n                temperature=options.temperature,\n                seed=options.seed,\n            )\n            context.append(token)\n            if token == return_token:\n                print(flush=True)\n                break\n            elif token == start_token:\n                inside_start_block = True\n                role = \"\"\n                channel = \"\"\n            elif token == message_token:\n                inside_start_block = False\n                inside_channel_block = False\n                if channel == \"analysis\":\n                    print(f\"{GREY}\", end=\"\", flush=True)\n            elif token == end_token:\n                print(f\"{RESET}\", flush=True)\n            elif token == channel_token:\n                inside_channel_block = True\n            elif token < tokenizer.num_text_tokens:\n                if inside_channel_block:\n                    channel += str(tokenizer.decode(token), encoding=\"utf-8\")\n                elif inside_start_block:\n                    role += str(tokenizer.decode(token), encoding=\"utf-8\")\n                else:\n                    sys.stdout.buffer.write(tokenizer.decode(token))\n                    sys.stdout.buffer.flush()\n\n\nif __name__ == \"__main__\":\n    main(sys.argv[1:])\n"
  },
  {
    "path": "gpt_oss/metal/examples/generate.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nimport sys\n\nfrom gpt_oss.metal import Context, Model\n\n\nparser = argparse.ArgumentParser(description='Generate text with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss model in Metal inference format')\nparser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')\nparser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')\nparser.add_argument('--context-length', type=int, default=0, help='The maximum context length')\n\n\ndef main(args):\n    options = parser.parse_args(args)\n    model = Model(options.model)\n\n    context = Context(model, context_length=options.context_length)\n    context.append(options.prompt)\n    print(context.tokens)\n    prompt_tokens = context.num_tokens\n\n    tokenizer = model.tokenizer\n\n    while context.num_tokens - prompt_tokens < options.limit:\n        # Context.sample() requires max_output_tokens and returns a list of token IDs\n        token = context.sample(max_output_tokens=1)[0]\n        context.append(token)\n        print(str(tokenizer.decode(token), encoding=\"utf-8\"), end='', flush=True)\n\n\nif __name__ == '__main__':\n    main(sys.argv[1:])\n"
  },
  {
    "path": "gpt_oss/metal/include/gpt-oss/functions.h",
    "content": "#pragma once\n\n#include <stddef.h>\n#include <stdint.h>\n\n#include <gpt-oss/macros.h>\n#include <gpt-oss/types.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/*\n * Creates a Model object from a file in the filesystem.\n *\n * @param path Path to the file containing the model in GPT-OSS format.\n * @param model_out Pointer to the Model object that will be created. Must be released with gptoss_model_release.\n *\n * On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.\n * On failure, returns an error code and stores a null pointer in the model_out argument.\n */\nenum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(\n    const char* path,\n    gptoss_model_t* model_out);\n\n/*\n * Query the Tokenizer object associated with the Model.\n *\n * @param model Pointer to the Model object created by gptoss_model_create_from_file.\n * @param tokenizer_out Pointer to the variable where the Tokenizer reference will be stored.\n *\n * On success, returns gptoss_status_success and stores a reference to the Tokenizer object in the tokenizer_out argument.\n * On failure, returns an error code and stores NULL in the tokenizer_out argument.\n */\nenum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(\n    gptoss_model_t model,\n    gptoss_tokenizer_t* tokenizer_out);\n\n/*\n * Query the maximum context length supported by the Model.\n *\n * @param model Pointer to the Model object created by gptoss_model_create_from_file.\n * @param max_context_length_out Pointer to the variable where the maximum context length will be stored.\n *\n * On success, returns gptoss_status_success and stores the maximum context length in the max_context_length_out argument.\n * On failure, returns an error code and leaves the value specified by max_context_length_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(\n    gptoss_model_t model,\n    size_t* max_context_length_out);\n\n/*\n * 
Increments a Model object's reference count.\n *\n * @param model Pointer to the Model object created by gptoss_model_create_from_file.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_model_retain(\n    gptoss_model_t model);\n\n/*\n * Decrements a Model object's reference count and possibly releases associated resources.\n *\n * @param model Pointer to the Model object created by gptoss_model_create_from_file.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_model_release(\n    gptoss_model_t model);\n\n/*\n * Query the token ID for a special token in the Tokenizer vocabulary.\n *\n * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.\n * @param token_type Type of the special token to query an ID for.\n * @param token_id_out Pointer to the variable where the token ID will be stored.\n *\n * On success, returns gptoss_status_success and stores the token ID in the token_id_out argument.\n * On failure, returns an error code and leaves the value specified by token_id_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(\n    gptoss_tokenizer_t tokenizer,\n    enum gptoss_special_token token_type,\n    uint32_t* token_id_out);\n\n/*\n * Query the number of text tokens in the Tokenizer vocabulary.\n *\n * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.\n * @param num_text_tokens_out Pointer to the variable where the number of text tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the number of text tokens in the num_text_tokens_out argument.\n * On failure, returns an error code and leaves the value specified by num_text_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* 
num_text_tokens_out);\n\n/*\n * Query the number of special tokens in the Tokenizer vocabulary.\n *\n * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.\n * @param num_special_tokens_out Pointer to the variable where the number of special tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the number of special tokens in the num_special_tokens_out argument.\n * On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* num_special_tokens_out);\n\n/*\n * Query the total number of tokens in the Tokenizer vocabulary.\n *\n * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.\n * @param num_tokens_out Pointer to the variable where the total number of tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the total number of tokens in the num_tokens_out argument.\n * On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* num_tokens_out);\n\n/*\n * Convert a text token ID to its byte representation.\n *\n * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer. 
The lifetime of the returned\n *                  byte representation matches the lifetime of this Tokenizer object.\n * @param token_id ID of the text token to convert.\n * @param token_ptr_out Pointer to the variable where the pointer to the byte representation of the token will be\n *                      stored.\n * @param token_size_out Pointer to the variable where the size of the byte representation of the token will be stored.\n *\n * On success, returns gptoss_status_success and stores the pointer to and size of the byte representation of the token in the\n *                     token_ptr_out and token_size_out arguments.\n * On failure, returns an error code and leaves the values specified in token_ptr_out and token_size_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t token_id,\n    const void** token_ptr_out,\n    size_t* token_size_out);\n\n/*\n * Increments a Tokenizer object's reference count.\n *\n * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(\n    gptoss_tokenizer_t tokenizer);\n\n/*\n * Decrements a Tokenizer object's reference count and possibly releases associated resources.\n *\n * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(\n    gptoss_tokenizer_t tokenizer);\n\n/*\n * Creates a Context object for use with a particular Model object.\n *\n * @param model Model object to create a context for.\n * @param context_length Maximum number of tokens in the context.\n *                       Specify 0 to use the maximum context length supported by the model.\n * @param max_batch_tokens Maximum number of tokens that can be processed in a single batch.\n *        
                Larger values may improve prefill performance, but require more memory.\n *                       Specify 0 to use the default value.\n * @param context_out Pointer to the Context object that will be created.\n *                    Must be released with gptoss_context_release.\n *\n * On success, returns gptoss_status_success and saves a pointer to the created Context in the context_out argument.\n * On failure, returns an error code and stores a null pointer in the context_out argument.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_create(\n    gptoss_model_t model,\n    size_t context_length,\n    size_t max_batch_tokens,\n    gptoss_context_t* context_out);\n\n/*\n * Query the current number of tokens cached in the Context.\n *\n * @param context Pointer to the Context object created by gptoss_context_create.\n * @param num_tokens_out Pointer to the variable where the current number of cached tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the current number of cached tokens in the num_tokens_out argument.\n * On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(\n    gptoss_context_t context,\n    size_t* num_tokens_out);\n\n/*\n * Query the maximum number of tokens that can be cached in the Context.\n *\n * @param context Pointer to the Context object created by gptoss_context_create.\n * @param max_tokens_out Pointer to the variable where the maximum number of cached tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the maximum number of cached tokens in the max_tokens_out argument.\n * On failure, returns an error code and leaves the value specified by max_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(\n    gptoss_context_t context,\n    size_t* max_tokens_out);\n\n/*\n * Query the list of token IDs cached in the Context.\n *\n * @param context 
Pointer to the Context object created by gptoss_context_create.\n * @param tokens_out Pointer to the array where up to max_tokens cached token IDs will be stored.\n * @param max_tokens Maximum capacity of the buffer specified by tokens_out.\n * @param num_tokens_out Pointer to the variable where the actual number of cached tokens will be stored.\n *                       This value can exceed max_tokens if the buffer capacity is insufficient.\n *\n * On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of\n * cached tokens in the num_tokens_out argument.\n * On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(\n    gptoss_context_t context,\n    uint32_t* tokens_out,\n    size_t max_tokens,\n    size_t* num_tokens_out);\n\n/*\n * Tokenizes and appends a character string to the Context object.\n *\n * @param context Context object created by gptoss_context_create.\n * @param text Pointer to the character string to tokenize and append.\n * @param text_length Length of the string, in chars.\n * @param num_tokens_out Optional pointer to the variable where the number of appended tokens will be stored. 
Ignored if a null pointer is provided.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_append_chars(\n    gptoss_context_t context,\n    const char* text,\n    size_t text_length,\n    size_t* num_tokens_out);\n\n/*\n * Appends a list of tokens to the context.\n *\n * @param context Context object created by gptoss_context_create.\n * @param num_tokens Number of tokens to be appended.\n * @param tokens Pointer to the array of tokens to be appended.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(\n    gptoss_context_t context,\n    size_t num_tokens,\n    const uint32_t* tokens);\n\n/*\n * Resets the context, clearing its state.\n *\n * @param context Context object created by gptoss_context_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_reset(\n    gptoss_context_t context);\n\n/*\n * Pre-process the tokens in the Context and generate the probability distribution over the next token.\n *\n * @param context Context object created by gptoss_context_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_process(\n    gptoss_context_t context);\n\n/*\n * Sample tokens from the probability distribution over the next token conditioned on the Context.\n *\n * @param context Context object created by gptoss_context_create.\n * @param temperature Sampling temperature. 
Must be non-negative.\n * @param seed Random number generator seed to use for sampling.\n * @param max_tokens Maximum number of tokens to sample; also the capacity of the buffer specified by tokens_out.\n * @param tokens_out Pointer to the array where up to max_tokens sampled token IDs will be stored.\n * @param num_tokens_out Pointer to the variable where the number of sampled tokens will be stored.\n *\n * On success, returns gptoss_status_success and stores the sampled token IDs in the tokens_out argument and their\n * number in the num_tokens_out argument. On failure, returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_sample(\n    gptoss_context_t context,\n    float temperature,\n    uint64_t seed,\n    size_t max_tokens,\n    uint32_t* tokens_out,\n    size_t* num_tokens_out);\n\n/*\n * Increments a Context object's reference count.\n *\n * @param context Pointer to the Context object created by gptoss_context_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_retain(\n    gptoss_context_t context);\n\n/*\n * Decrements a Context object's reference count and possibly releases associated resources.\n *\n * @param context Pointer to the Context object created by gptoss_context_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_context_release(\n    gptoss_context_t context);\n\n/*\n * Creates a Sampler object.\n *\n * @param sampler_out Pointer to the Sampler object that will be created.\n *                    Must be released with gptoss_sampler_release.\n *\n * On success, returns gptoss_status_success and saves a pointer to the created Sampler in the sampler_out argument.\n * On failure, returns an error code and stores a null pointer in the sampler_out argument.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_create(\n    gptoss_sampler_t* sampler_out);\n\n/*\n * Sets the sampling temperature for the Sampler.\n *\n * @param sampler Sampler object created by gptoss_sampler_create.\n * @param temperature Temperature value to be set. 
Must be in the [0.0, 1.0] range.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_set_temperature(\n    gptoss_sampler_t sampler,\n    float temperature);\n\n/*\n * Sets the Top-P nucleus sampling parameter for the Sampler.\n *\n * @param sampler Sampler object created by gptoss_sampler_create.\n * @param top_p Top-P value to be set. Must be in the (0.0, 1.0] range.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_set_top_p(\n    gptoss_sampler_t sampler,\n    float top_p);\n\n/*\n * Sets the presence penalty for the Sampler.\n *\n * @param sampler Sampler object created by gptoss_sampler_create.\n * @param presence_penalty Presence penalty value to be set. Must be in the [-2.0, 2.0] range.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_set_presence_penalty(\n    gptoss_sampler_t sampler,\n    float presence_penalty);\n\n/*\n * Sets the frequency penalty for the Sampler.\n *\n * @param sampler Sampler object created by gptoss_sampler_create.\n * @param frequency_penalty Frequency penalty value to be set. 
Must be in the [-2.0, 2.0] range.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_set_frequency_penalty(\n    gptoss_sampler_t sampler,\n    float frequency_penalty);\n\n/*\n * Increments a Sampler object's reference count.\n *\n * @param sampler Pointer to the Sampler object created by gptoss_sampler_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_retain(\n    gptoss_sampler_t sampler);\n\n/*\n * Decrements a Sampler object's reference count and possibly releases associated resources.\n *\n * @param sampler Pointer to the Sampler object created by gptoss_sampler_create.\n *\n * On success, returns gptoss_status_success, otherwise returns an error code.\n */\nenum gptoss_status GPTOSS_ABI gptoss_sampler_release(\n    gptoss_sampler_t sampler);\n\n#ifdef __cplusplus\n}  // extern \"C\"\n#endif\n"
  },
  {
    "path": "gpt_oss/metal/include/gpt-oss/macros.h",
    "content": "#pragma once\n\n#ifndef GPTOSS_ABI\n    #define GPTOSS_ABI\n#endif  // GPTOSS_ABI\n"
  },
  {
    "path": "gpt_oss/metal/include/gpt-oss/types.h",
    "content": "#pragma once\n\n/*\n * Status codes returned by GPT-OSS API functions.\n */\nenum gptoss_status {\n    gptoss_status_success = 0,\n    gptoss_status_invalid_argument = 1,\n    gptoss_status_unsupported_argument = 2,\n    gptoss_status_invalid_state = 3,\n    gptoss_status_io_error = 4,\n    gptoss_status_insufficient_memory = 5,\n    gptoss_status_insufficient_resources = 6,\n    gptoss_status_unsupported_system = 7,\n    gptoss_status_context_overflow = 8,\n};\n\nenum gptoss_special_token {\n    gptoss_special_token_invalid = 0,\n    gptoss_special_token_return = 1,\n    gptoss_special_token_start = 2,\n    gptoss_special_token_message = 3,\n    gptoss_special_token_end = 4,\n    gptoss_special_token_refusal = 5,\n    gptoss_special_token_constrain = 6,\n    gptoss_special_token_channel = 7,\n    gptoss_special_token_call = 8,\n    gptoss_special_token_untrusted = 9,\n    gptoss_special_token_end_untrusted = 10,\n    gptoss_special_token_max,\n};\n\n/*\n * Model object is an opaque container comprised of:\n * - Weights\n * - Temporary buffers required to run the model\n * - Any other resources required to run the model\n */\ntypedef struct gptoss_model* gptoss_model_t;\n\ntypedef struct gptoss_tokenizer* gptoss_tokenizer_t;\n\n/*\n * Context is an opaque container comprised of:\n * - Input tokens\n * - Distribution over the output tokens\n * - KV cache\n *\n * Multiple contexts can be created and used with the same model.\n */\ntypedef struct gptoss_context* gptoss_context_t;\n\n/*\n * Sampler is an opaque container for sampling parameters:\n * - Temperature\n * - Top-p (nucleus sampling)\n * - Frequency penalty\n * - Presence penalty\n *\n * Multiple samplers can be created and used with the same context.\n */\ntypedef struct gptoss_sampler* gptoss_sampler_t;\n"
  },
  {
    "path": "gpt_oss/metal/include/gpt-oss.h",
    "content": "#pragma once\n\n#include <gpt-oss/macros.h>\n#include <gpt-oss/types.h>\n#include <gpt-oss/functions.h>\n"
  },
  {
    "path": "gpt_oss/metal/python/context.c",
    "content": "#include <Python.h>\n\n#include <gpt-oss.h>\n\n#include \"module.h\"\n\n\nstatic int PyGPTOSSContext_init(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {\n    static char *kwlist[] = {\"model\", \"context_length\", \"max_batch_tokens\", NULL};\n    PyObject* model = NULL;\n    Py_ssize_t context_length = 0; // 0 selects the maximum context length supported by the model\n    Py_ssize_t max_batch_tokens = 0; // 0 selects the default batch size\n\n    // The 'n' format unit stores into a Py_ssize_t ('i' would write an int).\n    if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O!|$nn\", kwlist,\n                                     &PyGPTOSSModel_Type, &model,\n                                     &context_length, &max_batch_tokens))\n    {\n        return -1;\n    }\n    if (context_length < 0) {\n        PyErr_SetString(PyExc_ValueError, \"context_length must be a non-negative integer\");\n        return -1;\n    }\n    if (max_batch_tokens < 0) {\n        PyErr_SetString(PyExc_ValueError, \"max_batch_tokens must be a non-negative integer\");\n        return -1;\n    }\n\n    enum gptoss_status status = gptoss_context_create(\n        ((const PyGPTOSSModel*) model)->handle,\n        (size_t) context_length,\n        (size_t) max_batch_tokens,\n        &self->handle);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        goto error;\n    }\n\n    return 0;\n\nerror:\n    gptoss_context_release(self->handle);\n    self->handle = NULL;\n    return -1;\n}\n\nstatic void PyGPTOSSContext_dealloc(PyGPTOSSContext* self) {\n    (void) gptoss_context_release(self->handle);\n    self->handle = NULL;\n    PyObject_Del((PyObject*) self);\n}\n\nstatic PyObject* PyGPTOSSContext_copy(PyGPTOSSContext *self) {\n    PyGPTOSSContext* copy = (PyGPTOSSContext*) PyObject_New(PyGPTOSSContext, Py_TYPE(self));\n    if (copy == NULL) {\n        return NULL;\n    }\n\n    (void) gptoss_context_retain(self->handle);\n    copy->handle = self->handle;\n    return (PyObject*) copy;\n}\n\nstatic PyObject* PyGPTOSSContext_append(PyGPTOSSContext* self, PyObject* arg) {\n    if 
(PyBytes_Check(arg)) {\n        char* string_ptr = NULL;\n        Py_ssize_t string_size = 0;\n        if (PyBytes_AsStringAndSize(arg, &string_ptr, &string_size) < 0) {\n            return NULL;\n        }\n\n        const enum gptoss_status status = gptoss_context_append_chars(\n            self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);\n        if (status != gptoss_status_success) {\n            // TODO: set exception\n            return NULL;\n        }\n\n        Py_RETURN_NONE;\n    } else if (PyUnicode_Check(arg)) {\n        Py_ssize_t string_size = 0;\n        const char* string_ptr = PyUnicode_AsUTF8AndSize(arg, &string_size);\n        if (string_ptr == NULL) {\n            return NULL;\n        }\n\n        const enum gptoss_status status = gptoss_context_append_chars(\n            self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);\n        if (status != gptoss_status_success) {\n            // TODO: set exception\n            return NULL;\n        }\n\n        Py_RETURN_NONE;\n    } else if (PyLong_Check(arg)) {\n        const unsigned long token_as_ulong = PyLong_AsUnsignedLong(arg);\n        if (token_as_ulong == (unsigned long) -1 && PyErr_Occurred()) {\n            return NULL;\n        }\n        // Token IDs are 32-bit; reject values that would be silently truncated.\n        if (token_as_ulong > UINT32_MAX) {\n            PyErr_SetString(PyExc_OverflowError, \"token ID does not fit in 32 bits\");\n            return NULL;\n        }\n\n        const uint32_t token = (uint32_t) token_as_ulong;\n        const enum gptoss_status status = gptoss_context_append_tokens(\n            self->handle, /*num_tokens=*/1, &token);\n        if (status != gptoss_status_success) {\n            // TODO: set exception\n            return NULL;\n        }\n\n        Py_RETURN_NONE;\n    } else {\n        PyErr_SetString(PyExc_TypeError, \"expected a bytes, str, or integer argument\");\n        return NULL;\n    }\n}\n\nstatic PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {\n    const enum gptoss_status status = gptoss_context_process(self->handle);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return NULL;\n    }\n\n    
Py_RETURN_NONE;\n}\n\nstatic PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {\n    static char *kwlist[] = {\"max_output_tokens\", \"temperature\", \"seed\", NULL};\n    PyObject* token_list_obj = NULL;\n    uint32_t* token_ptr = NULL;\n\n    unsigned int max_output_tokens = 0;\n    unsigned long long seed = 0;\n    float temperature = 1.0f;\n    if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"I|$fK\", kwlist,\n            &max_output_tokens, &temperature, &seed))\n    {\n        return NULL;\n    }\n\n    token_ptr = (uint32_t*) PyMem_Malloc(max_output_tokens * sizeof(uint32_t));\n    if (token_ptr == NULL) {\n        PyErr_NoMemory();\n        goto error;\n    }\n\n    size_t num_tokens = 0;\n    const enum gptoss_status status = gptoss_context_sample(\n        self->handle, temperature, (uint64_t) seed,\n        (size_t) max_output_tokens, token_ptr, &num_tokens);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        goto error;\n    }\n\n    token_list_obj = PyList_New((Py_ssize_t) num_tokens);\n    if (token_list_obj == NULL) {\n        goto error;\n    }\n\n    for (size_t t = 0; t < num_tokens; t++) {\n        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);\n        if (token_obj == NULL) {\n            goto error;\n        }\n\n        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);\n    }\n\n    PyMem_Free(token_ptr);\n    return token_list_obj;\n\nerror:\n    PyMem_Free(token_ptr);\n    Py_XDECREF(token_list_obj);\n    return NULL;\n}\n\nstatic PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {\n    const enum gptoss_status status = gptoss_context_reset(self->handle);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return NULL;\n    }\n\n    Py_RETURN_NONE;\n}\n\nstatic PyMethodDef PyGPTOSSContext_methods[] = {\n    {\"__copy__\", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, \"Create a copy of the Context\"},\n    
{\"append\", (PyCFunction) PyGPTOSSContext_append, METH_O, \"Append text (str or bytes) or a token ID to the Context\"},\n    {\"process\", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, \"Process tokens in the Context\"},\n    {\"sample\", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, \"Sample token predictions from the Context\"},\n    {\"reset\", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, \"Discard the content of the Context\"},\n    {NULL},\n};\n\nstatic PyObject* PyGPTOSSContext_get_num_tokens(PyGPTOSSContext* self, void* closure) {\n    size_t num_tokens = 0;\n    const enum gptoss_status status = gptoss_context_get_num_tokens(self->handle, &num_tokens);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return NULL;\n    }\n\n    return PyLong_FromSize_t(num_tokens);\n}\n\nstatic PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* closure) {\n    size_t max_tokens = 0;\n    const enum gptoss_status status = gptoss_context_get_max_tokens(self->handle, &max_tokens);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return NULL;\n    }\n\n    return PyLong_FromSize_t(max_tokens);\n}\n\nstatic PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {\n    PyObject* token_list_obj = NULL;\n    uint32_t* token_ptr = NULL;\n\n    size_t num_tokens = 0;\n    gptoss_context_get_tokens(self->handle, /*tokens_out=*/NULL, /*max_tokens=*/0, &num_tokens);\n\n    if (num_tokens != 0) {\n        token_ptr = (uint32_t*) PyMem_Malloc(num_tokens * sizeof(uint32_t));\n        if (token_ptr == NULL) {\n            PyErr_NoMemory();\n            goto error;\n        }\n\n        enum gptoss_status status = gptoss_context_get_tokens(self->handle, token_ptr, /*max_tokens=*/num_tokens, &num_tokens);\n        if (status != gptoss_status_success) {\n            // TODO: set exception\n            goto error;\n        }\n    }\n\n    token_list_obj = 
PyList_New((Py_ssize_t) num_tokens);\n    if (token_list_obj == NULL) {\n        goto error;\n    }\n\n    for (size_t t = 0; t < num_tokens; t++) {\n        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);\n        if (token_obj == NULL) {\n            goto error;\n        }\n\n        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);\n    }\n\n    PyMem_Free(token_ptr);\n    return token_list_obj;\n\nerror:\n    PyMem_Free(token_ptr);\n    Py_XDECREF(token_list_obj);\n    return NULL;\n}\n\nstatic PyGetSetDef PyGPTOSSContext_getseters[] = {\n    (PyGetSetDef) {\n        .name = \"num_tokens\",\n        .get = (getter) PyGPTOSSContext_get_num_tokens,\n        .doc = \"Current number of tokens in the context\",\n    },\n    (PyGetSetDef) {\n        .name = \"max_tokens\",\n        .get = (getter) PyGPTOSSContext_get_max_tokens,\n        .doc = \"Maximum number of tokens in the context\",\n    },\n    (PyGetSetDef) {\n        .name = \"tokens\",\n        .get = (getter) PyGPTOSSContext_get_tokens,\n        .doc = \"List of token IDs in the context\",\n    },\n    {NULL}  /* Sentinel */\n};\n\nPyTypeObject PyGPTOSSContext_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0)\n    .tp_name = \"gptoss.Context\",\n    .tp_basicsize = sizeof(PyGPTOSSContext),\n    .tp_flags = 0\n        | Py_TPFLAGS_DEFAULT\n        | Py_TPFLAGS_BASETYPE,\n    .tp_doc = \"Context object\",\n    .tp_methods = PyGPTOSSContext_methods,\n    .tp_getset = PyGPTOSSContext_getseters,\n    .tp_new = PyType_GenericNew,\n    .tp_init = (initproc) PyGPTOSSContext_init,\n    .tp_dealloc = (destructor) PyGPTOSSContext_dealloc,\n};\n"
  },
  {
    "path": "gpt_oss/metal/python/model.c",
    "content": "#include <Python.h>\n\n#include <gpt-oss.h>\n\n#include \"module.h\"\n\n\nstatic int PyGPTOSSModel_init(PyGPTOSSModel* self, PyObject* args, PyObject* kwargs) {\n    enum gptoss_status status;\n    const char* filepath;\n\n    if (!PyArg_ParseTuple(args, \"s\", &filepath)) {\n        return -1;\n    }\n    status = gptoss_model_create_from_file(filepath, &self->handle);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return -1;\n    }\n    return 0;\n}\n\nstatic void PyGPTOSSModel_dealloc(PyGPTOSSModel* self) {\n    (void) gptoss_model_release(self->handle);\n    self->handle = NULL;\n    PyObject_Del((PyObject*) self);\n}\n\nstatic PyObject* PyGPTOSSModel_copy(PyGPTOSSModel* self) {\n    PyGPTOSSModel* copy = (PyGPTOSSModel*) PyObject_New(PyGPTOSSModel, Py_TYPE(self));\n    if (copy == NULL) {\n        return NULL;\n    }\n\n    (void) gptoss_model_retain(self->handle);\n    copy->handle = self->handle;\n    return (PyObject*) copy;\n}\n\nstatic PyMethodDef PyGPTOSSModel_methods[] = {\n    {\"__copy__\", (PyCFunction) PyGPTOSSModel_copy, METH_NOARGS, \"Create a copy of the Model\"},\n    {NULL},\n};\n\nstatic PyObject *PyGPTOSSModel_get_max_context_length(PyGPTOSSModel* self, void* closure) {\n    size_t max_context_length = 0;\n    const enum gptoss_status status = gptoss_model_get_max_context_length(self->handle, &max_context_length);\n    if (status != gptoss_status_success) {\n        // TODO: set exception\n        return NULL;\n    }\n\n    return PyLong_FromSize_t(max_context_length);\n}\n\nstatic PyObject *PyGPTOSSModel_get_tokenizer(PyGPTOSSModel* self, void* closure) {\n    PyObject* args = PyTuple_Pack(1, self);\n    if (args == NULL) {\n        return NULL;\n    }\n\n    PyObject* tokenizer = PyObject_CallObject((PyObject*) &PyGPTOSSTokenizer_Type, args);\n    Py_DECREF(args);\n    return tokenizer;\n}\n\nstatic PyGetSetDef PyGPTOSSModel_getseters[] = {\n    (PyGetSetDef) {\n        .name = 
\"max_context_length\",\n        .get = (getter) PyGPTOSSModel_get_max_context_length,\n        .doc = \"Maximum context length supported by the model\",\n    },\n    (PyGetSetDef) {\n        .name = \"tokenizer\",\n        .get = (getter) PyGPTOSSModel_get_tokenizer,\n        .doc = \"Tokenizer object associated with the model\",\n    },\n    {NULL}  // Sentinel\n};\n\nPyTypeObject PyGPTOSSModel_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0)\n    .tp_name = \"gptoss.Model\",\n    .tp_basicsize = sizeof(PyGPTOSSModel),\n    .tp_flags = 0\n        | Py_TPFLAGS_DEFAULT\n        | Py_TPFLAGS_BASETYPE,\n    .tp_doc = \"Model object\",\n    .tp_methods = PyGPTOSSModel_methods,\n    .tp_getset = PyGPTOSSModel_getseters,\n    .tp_new = PyType_GenericNew,\n    .tp_init = (initproc) PyGPTOSSModel_init,\n    .tp_dealloc = (destructor) PyGPTOSSModel_dealloc,\n};\n"
  },
  {
    "path": "gpt_oss/metal/python/module.c",
    "content": "#include <Python.h>\n\n#include \"module.h\"\n\n\nstatic PyMethodDef module_methods[] = {\n    {NULL, NULL, 0, NULL}\n};\n\nstatic PyModuleDef metal_module = {\n    PyModuleDef_HEAD_INIT,\n    \"_metal\",\n    \"Local GPT-OSS inference\",\n    -1,\n    module_methods\n};\n\nPyMODINIT_FUNC PyInit__metal(void) {\n    PyObject* module = NULL;\n    PyObject* model_type = NULL;\n    PyObject* tokenizer_type = NULL;\n    PyObject* context_type = NULL;\n\n    if (PyType_Ready(&PyGPTOSSModel_Type) < 0) {\n        goto error;\n    }\n    model_type = (PyObject*) &PyGPTOSSModel_Type;\n    Py_INCREF(model_type);\n\n    if (PyType_Ready(&PyGPTOSSTokenizer_Type) < 0) {\n        goto error;\n    }\n    tokenizer_type = (PyObject*) &PyGPTOSSTokenizer_Type;\n    Py_INCREF(tokenizer_type);\n\n    if (PyType_Ready(&PyGPTOSSContext_Type) < 0) {\n        goto error;\n    }\n    context_type = (PyObject*) &PyGPTOSSContext_Type;\n    Py_INCREF(context_type);\n\n    module = PyModule_Create(&metal_module);\n    if (module == NULL) {\n        goto error;\n    }\n\n    // PyModule_AddObject steals the reference only on success, so clear each\n    // local pointer after a successful call to avoid a double decref on error.\n    if (PyModule_AddObject(module, \"Model\", model_type) < 0) {\n        goto error;\n    }\n    model_type = NULL;\n\n    if (PyModule_AddObject(module, \"Tokenizer\", tokenizer_type) < 0) {\n        goto error;\n    }\n    tokenizer_type = NULL;\n\n    if (PyModule_AddObject(module, \"Context\", context_type) < 0) {\n        goto error;\n    }\n    context_type = NULL;\n\n    return module;\n\nerror:\n    Py_XDECREF(context_type);\n    Py_XDECREF(tokenizer_type);\n    Py_XDECREF(model_type);\n    Py_XDECREF(module);\n    return NULL;\n}\n"
  },
  {
    "path": "gpt_oss/metal/python/module.h",
    "content": "#include <Python.h>\n\n#include <gpt-oss.h>\n\ntypedef struct {\n    PyObject_HEAD\n    gptoss_model_t handle;\n} PyGPTOSSModel;\n\ntypedef struct {\n    PyObject_HEAD\n    gptoss_tokenizer_t handle;\n} PyGPTOSSTokenizer;\n\ntypedef struct {\n    PyObject_HEAD\n    gptoss_context_t handle;\n} PyGPTOSSContext;\n\nextern PyTypeObject PyGPTOSSModel_Type;\nextern PyTypeObject PyGPTOSSTokenizer_Type;\nextern PyTypeObject PyGPTOSSContext_Type;\n"
  },
  {
    "path": "gpt_oss/metal/python/tokenizer.c",
    "content": "#include <Python.h>\n\n#include <gpt-oss.h>\n\n#include \"module.h\"\n\nstatic PyObject* PyGPTOSSTokenizer_new(PyTypeObject* subtype, PyObject* args, PyObject* kwargs) {\n    static char *kwlist[] = {\"model\", NULL};\n    PyObject* model = NULL;\n    if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"O!\", kwlist, &PyGPTOSSModel_Type, &model)) {\n        return NULL;\n    }\n\n    PyGPTOSSTokenizer* self = (PyGPTOSSTokenizer*) subtype->tp_alloc(subtype, 0);\n    if (self == NULL) {\n        return NULL;\n    }\n\n    const enum gptoss_status status = gptoss_model_get_tokenizer(\n        ((const PyGPTOSSModel*) model)->handle,\n        &self->handle);\n    if (status != gptoss_status_success) {\n        PyErr_Format(PyExc_RuntimeError, \"failed to get tokenizer from model (status %d)\", (int) status);\n        // Free the partially constructed object directly: tp_dealloc would release a handle that was never acquired\n        subtype->tp_free((PyObject*) self);\n        return NULL;\n    }\n\n    return (PyObject*) self;\n}\n\nstatic void PyGPTOSSTokenizer_dealloc(PyGPTOSSTokenizer* self) {\n    (void) gptoss_tokenizer_release(self->handle);\n    self->handle = NULL;\n    // Free through the type's tp_free so subclasses use the matching deallocator\n    Py_TYPE(self)->tp_free((PyObject*) self);\n}\n\nstatic PyObject* PyGPTOSSTokenizer_copy(PyGPTOSSTokenizer* self) {\n    PyGPTOSSTokenizer* copy = (PyGPTOSSTokenizer*) PyObject_New(PyGPTOSSTokenizer, Py_TYPE(self));\n    if (copy == NULL) {\n        return NULL;\n    }\n\n    (void) gptoss_tokenizer_retain(self->handle);\n    copy->handle = self->handle;\n    return (PyObject*) copy;\n}\n\nstatic PyObject* PyGPTOSSTokenizer_encode_special_token(PyGPTOSSTokenizer* self, PyObject* arg) {\n    if (PyUnicode_Check(arg)) {\n        const char* string_ptr = PyUnicode_AsUTF8(arg);\n        if (string_ptr == NULL) {\n            return NULL;\n        }\n\n        enum gptoss_special_token token_type = gptoss_special_token_invalid;\n        if (strcmp(string_ptr, \"<|return|>\") == 0) {\n            token_type = gptoss_special_token_return;\n        } else if (strcmp(string_ptr, \"<|start|>\") == 0) {\n            token_type = gptoss_special_token_start;\n        } else if (strcmp(string_ptr, \"<|message|>\") == 0) {\n            
token_type = gptoss_special_token_message;\n        } else if (strcmp(string_ptr, \"<|end|>\") == 0) {\n            token_type = gptoss_special_token_end;\n        } else if (strcmp(string_ptr, \"<|refusal|>\") == 0) {\n            token_type = gptoss_special_token_refusal;\n        } else if (strcmp(string_ptr, \"<|constrain|>\") == 0) {\n            token_type = gptoss_special_token_constrain;\n        } else if (strcmp(string_ptr, \"<|channel|>\") == 0) {\n            token_type = gptoss_special_token_channel;\n        } else if (strcmp(string_ptr, \"<|call|>\") == 0) {\n            token_type = gptoss_special_token_call;\n        } else if (strcmp(string_ptr, \"<|untrusted|>\") == 0) {\n            token_type = gptoss_special_token_untrusted;\n        } else if (strcmp(string_ptr, \"<|end_untrusted|>\") == 0) {\n            token_type = gptoss_special_token_end_untrusted;\n        } else {\n            PyErr_Format(PyExc_ValueError, \"unrecognized special token: %s\", string_ptr);\n            return NULL;\n        }\n\n        uint32_t token_id = UINT32_MAX;\n        const enum gptoss_status status = gptoss_tokenizer_get_special_token_id(\n            self->handle, token_type, &token_id);\n        if (status != gptoss_status_success || token_id == UINT32_MAX) {\n            PyErr_Format(PyExc_ValueError, \"tokenizer does not support the %s token\", string_ptr);\n            return NULL;\n        }\n\n        return PyLong_FromUnsignedLong((unsigned long) token_id);\n    } else {\n        PyErr_SetString(PyExc_TypeError, \"string argument expected\");\n        return NULL;\n    }\n}\n\nstatic PyObject* PyGPTOSSTokenizer_decode(PyGPTOSSTokenizer* self, PyObject* args, PyObject* kwargs) {\n    static char *kwlist[] = {\"token\", NULL};\n    unsigned int token = 0; // Required argument; None is not accepted\n\n    if (!PyArg_ParseTupleAndKeywords(args, kwargs, \"I\", kwlist, &token)) {\n        return NULL;\n    }\n\n    const void* token_ptr = NULL;\n    size_t token_size = 0;\n    
const enum gptoss_status status = gptoss_tokenizer_decode(self->handle, (uint32_t) token, &token_ptr, &token_size);\n    if (status != gptoss_status_success) {\n        PyErr_Format(PyExc_RuntimeError, \"failed to decode token %u (status %d)\", token, (int) status);\n        return NULL;\n    }\n\n    return PyBytes_FromStringAndSize((const char*) token_ptr, (Py_ssize_t) token_size);\n}\n\nstatic PyMethodDef PyGPTOSSTokenizer_methods[] = {\n    {\"__copy__\", (PyCFunction) PyGPTOSSTokenizer_copy, METH_NOARGS, \"Create a copy of the Tokenizer\"},\n    {\"encode_special_token\", (PyCFunction) PyGPTOSSTokenizer_encode_special_token, METH_O, \"Query ID of a special token\"},\n    {\"decode\", (PyCFunction) PyGPTOSSTokenizer_decode, METH_VARARGS | METH_KEYWORDS, \"Convert text token ID to bytes\"},\n    {NULL},\n};\n\nstatic PyObject* PyGPTOSSTokenizer_get_num_text_tokens(PyGPTOSSTokenizer* self, void* closure) {\n    uint32_t num_text_tokens = 0;\n    const enum gptoss_status status = gptoss_tokenizer_get_num_text_tokens(self->handle, &num_text_tokens);\n    if (status != gptoss_status_success) {\n        PyErr_Format(PyExc_RuntimeError, \"failed to query number of text tokens (status %d)\", (int) status);\n        return NULL;\n    }\n\n    return PyLong_FromUnsignedLong((unsigned long) num_text_tokens);\n}\n\nstatic PyObject* PyGPTOSSTokenizer_get_num_special_tokens(PyGPTOSSTokenizer* self, void* closure) {\n    uint32_t num_special_tokens = 0;\n    const enum gptoss_status status = gptoss_tokenizer_get_num_special_tokens(self->handle, &num_special_tokens);\n    if (status != gptoss_status_success) {\n        PyErr_Format(PyExc_RuntimeError, \"failed to query number of special tokens (status %d)\", (int) status);\n        return NULL;\n    }\n\n    return PyLong_FromUnsignedLong((unsigned long) num_special_tokens);\n}\n\nstatic PyObject* PyGPTOSSTokenizer_get_num_tokens(PyGPTOSSTokenizer* self, void* closure) {\n    uint32_t num_tokens = 0;\n    const enum gptoss_status status = gptoss_tokenizer_get_num_tokens(self->handle, &num_tokens);\n    if (status != gptoss_status_success) {\n        PyErr_Format(PyExc_RuntimeError, \"failed to query number of tokens (status %d)\", (int) status);\n        return NULL;\n    }\n\n    return PyLong_FromUnsignedLong((unsigned long) 
num_tokens);\n}\n\nstatic PyGetSetDef PyGPTOSSTokenizer_getseters[] = {\n    (PyGetSetDef) {\n        .name = \"num_tokens\",\n        .get = (getter) PyGPTOSSTokenizer_get_num_tokens,\n        .doc = \"Total number of tokens in the tokenizer dictionary\",\n    },\n    (PyGetSetDef) {\n        .name = \"num_text_tokens\",\n        .get = (getter) PyGPTOSSTokenizer_get_num_text_tokens,\n        .doc = \"Number of text tokens in the tokenizer dictionary\",\n    },\n    (PyGetSetDef) {\n        .name = \"num_special_tokens\",\n        .get = (getter) PyGPTOSSTokenizer_get_num_special_tokens,\n        .doc = \"Number of special tokens in the tokenizer dictionary\",\n    },\n    {NULL}  /* Sentinel */\n};\n\nPyTypeObject PyGPTOSSTokenizer_Type = {\n    PyVarObject_HEAD_INIT(NULL, 0)\n    .tp_name = \"gptoss.Tokenizer\",\n    .tp_basicsize = sizeof(PyGPTOSSTokenizer),\n    .tp_flags = 0\n        | Py_TPFLAGS_DEFAULT\n        | Py_TPFLAGS_BASETYPE,\n    .tp_doc = \"Tokenizer object\",\n    .tp_methods = PyGPTOSSTokenizer_methods,\n    .tp_getset = PyGPTOSSTokenizer_getseters,\n    .tp_new = PyGPTOSSTokenizer_new,\n    .tp_dealloc = (destructor) PyGPTOSSTokenizer_dealloc,\n};\n"
  },
  {
    "path": "gpt_oss/metal/scripts/create-local-model.py",
    "content": "import argparse\nimport os\nimport math\nimport sys\nimport json\nimport itertools\nimport struct\nfrom uuid import UUID\n\nimport tiktoken\n\nimport torch\nfrom safetensors import safe_open\nfrom tqdm import tqdm\nfrom openai_harmony import load_harmony_encoding, HarmonyEncodingName\n\nparser = argparse.ArgumentParser(prog='create-local-model.py', description='Convert a checkpoint directory to a local model file')\nparser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')\nparser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')\n\n\no200k_base = tiktoken.get_encoding(\"o200k_base\")\nharmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\no200k_gptoss = tiktoken.Encoding(\n    name=\"o200k_gptoss\",\n    pat_str=o200k_base._pat_str,\n    mergeable_ranks=o200k_base._mergeable_ranks,\n    special_tokens={\n        \"<|reversed199998|>\": 199998,  # unused\n        \"<|endoftext|>\": 199999,\n        \"<|untrusted|>\": 200000,\n        \"<|endofuntrusted|>\": 200001,\n        \"<|return|>\": 200002,\n        \"<|constrain|>\": 200003,\n        \"<|reversed200004|>\": 200004,  # unused\n        \"<|channel|>\": 200005,\n        \"<|start|>\": 200006,\n        \"<|end|>\": 200007,\n        \"<|message|>\": 200008,\n        \"<|reversed200009|>\": 200009,  # unused\n        \"<|reversed200010|>\": 200010,  # unused\n        \"<|reversed200011|>\": 200011,  # unused\n        \"<|call|>\": 200012,\n        \"<|refusal|>\": 200013,\n    }\n)\n\nFILE_MAGIC = struct.pack('ccccccccccccI', b'G', b'P', b'T', b'-', b'O', b'S', b'S', b' ', b'v', b'1', b'.', b'0', 0)\nSPECIAL_TOKEN_UUID = {\n    '<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes,\n    '<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes,\n    '<|end|>': 
UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes,\n    '<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes,\n    '<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes,\n    '<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes,\n    '<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes,\n    '<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes,\n    '<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes,\n    '<|end_untrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes,\n}\n\nINCLUDE_SPECIAL_TOKENS = [\n    \"<|start|>\",\n    \"<|message|>\",\n    \"<|end|>\",\n    \"<|return|>\",\n    \"<|refusal|>\",\n    \"<|constrain|>\",\n    \"<|channel|>\",\n    \"<|call|>\",\n    \"<|untrusted|>\",\n    \"<|end_untrusted|>\",\n]\n\nGPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes\nAPPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes\nTIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes\n\nUE8_OFFSET = 14  # bias to MXFP4 block scales\n\ndef write_file_header(f):\n    f.write(FILE_MAGIC)\n\ndef write_tokenizer_header(f,\n                           num_special_tokens: int,\n                           num_text_tokens: int,\n                           regex_size: int,\n                           tokens_size: int):\n    f.write(TIKTOKEN_TOKENIZER_UUID)\n    f.write(struct.pack('<I', num_special_tokens))\n    f.write(struct.pack('<I', num_text_tokens))\n    f.write(struct.pack('<I', regex_size))\n    f.write(struct.pack('<I', tokens_size))\n\ndef write_model_header(f,\n                       context_length : int,\n                       num_blocks : int,\n                       num_experts : int,\n                       num_active_experts : int,\n                       embedding_dim : int,\n                       mlp_dim : int,\n                       swiglu_limit : float,\n                       head_dim: int,\n               
        num_heads : int,\n                       num_kv_heads : int,\n                       attention_window : int,\n                       rope_theta : float,\n                       interpolation_scale : float,\n                       yarn_offset : float,\n                       yarn_scale : float,\n                       yarn_multiplier : float,\n                       rmsnorm_epsilon : float):\n    f.write(GPTOSS_MODEL_UUID)\n    f.write(struct.pack('<I', context_length))\n    f.write(struct.pack('<I', num_blocks))\n    f.write(struct.pack('<I', num_experts))\n    f.write(struct.pack('<I', num_active_experts))\n    f.write(struct.pack('<I', embedding_dim))\n    f.write(struct.pack('<I', mlp_dim))\n    f.write(struct.pack('<f', swiglu_limit))\n    f.write(struct.pack('<I', head_dim))\n    f.write(struct.pack('<I', num_heads))\n    f.write(struct.pack('<I', num_kv_heads))\n    f.write(struct.pack('<I', attention_window))\n    f.write(struct.pack('<f', rope_theta))\n    f.write(struct.pack('<f', interpolation_scale))\n    f.write(struct.pack('<f', yarn_offset))\n    f.write(struct.pack('<f', yarn_scale))\n    f.write(struct.pack('<f', yarn_multiplier))\n    f.write(struct.pack('<f', rmsnorm_epsilon))\n    f.write(APPLE_GPU_LAYOUT_UUID)\n\n\ndef write_padding(out_file, alignment_multiple=16384):\n    offset = out_file.tell()\n    alignment_size = -offset % alignment_multiple\n    if alignment_size != 0:\n        alignment = bytes(alignment_size)\n        out_file.write(alignment)\n\n\ndef write_embedding_weight(out_file, weight):\n    write_padding(out_file, alignment_multiple=16)\n\n    assert weight.dtype == torch.float8_e4m3fn or weight.dtype == torch.bfloat16\n    out_file.write(weight.view(torch.uint8).numpy().tobytes())\n\n\ndef write_rmsnorm_gain(out_file, gain):\n    write_padding(out_file, alignment_multiple=16)\n\n    assert gain.dtype == torch.bfloat16\n    out_file.write(gain.view(torch.uint8).numpy().tobytes())\n\n\ndef write_attn_sink(out_file, 
sink):\n    write_padding(out_file, alignment_multiple=16)\n\n    assert sink.dtype == torch.bfloat16\n    out_file.write(sink.view(torch.uint8).numpy().tobytes())\n\n\ndef write_linear_weight(out_file, *args):\n    write_padding(out_file, alignment_multiple=16)\n\n    for t in args:\n        out_file.write(t.view(torch.uint8).numpy().tobytes())\n\n\ndef main(args):\n    options = parser.parse_args(args)\n\n    with open(os.path.join(options.src, \"config.json\"), \"r\") as f:\n        config = json.load(f)\n\n    num_blocks = config[\"num_hidden_layers\"]\n    num_experts = config[\"num_experts\"]\n    num_active_experts = 4\n    num_q_heads = config[\"num_attention_heads\"]\n    num_kv_heads = config[\"num_key_value_heads\"]\n    head_dim = config[\"head_dim\"]\n    embedding_dim = config[\"hidden_size\"]\n    mlp_dim = config[\"intermediate_size\"]\n    swiglu_limit = config.get(\"swiglu_limit\", 7.0)\n    rope_theta = config[\"rope_theta\"]\n    attention_window = config[\"sliding_window\"]\n    initial_context_length = config[\"initial_context_length\"]\n    rope_scaling_factor = config[\"rope_scaling_factor\"]\n    rope_ntk_alpha = config[\"rope_ntk_alpha\"]\n    rope_ntk_beta = config[\"rope_ntk_beta\"]\n\n    tokens_size = 0\n    num_text_tokens = 0\n    # First add all text tokens\n    for t in range(o200k_gptoss.n_vocab):\n        if not harmony_encoding.is_special_token(t):\n            token_bytes = o200k_gptoss.decode_single_token_bytes(t)\n            assert len(token_bytes) > 0\n            tokens_size += len(token_bytes) + 2  # uint16_t string length + string data\n            num_text_tokens += 1\n    # Then add all special tokens\n    num_included_tokens = 200013 + 1\n    print(f\"Tokenizer: {num_included_tokens} tokens\")\n\n    # Read from all files ending with .safetensors in the checkpoint directory\n    safetensor_files = [\n        os.path.join(options.src, fname)\n        for fname in os.listdir(options.src)\n        if 
fname.endswith(\".safetensors\")\n    ]\n    # Build a mapping from tensor name to filepath\n    tensor_name_to_file = {}\n    for safetensor_file in safetensor_files:\n        with safe_open(safetensor_file, framework=\"pt\", device=\"cpu\") as src:\n            for key in src.keys():\n                tensor_name_to_file[key] = safetensor_file\n\n    def get_tensor(name):\n        with safe_open(tensor_name_to_file[name], framework=\"pt\", device=\"cpu\") as src:\n            return src.get_tensor(name)\n\n    with open(options.dst, \"wb\") as dst:\n        write_file_header(dst)\n\n        yarn_low = (\n            head_dim / 2\n            * math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))\n            / math.log(rope_theta)\n        )\n        yarn_high = (\n            head_dim / 2\n            * math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))\n            / math.log(rope_theta)\n        )\n\n        write_model_header(dst,\n                            context_length=int(initial_context_length * rope_scaling_factor),\n                            num_blocks=num_blocks,\n                            num_experts=num_experts,\n                            num_active_experts=num_active_experts,\n                            embedding_dim=embedding_dim,\n                            mlp_dim=mlp_dim,\n                            swiglu_limit=swiglu_limit,\n                            head_dim=head_dim,\n                            num_heads=num_q_heads,\n                            num_kv_heads=num_kv_heads,\n                            attention_window=attention_window,\n                            rope_theta=rope_theta,\n                            interpolation_scale=1.0 / rope_scaling_factor,\n                            yarn_offset=-yarn_low / (yarn_high - yarn_low),\n                            yarn_scale=1.0 / (yarn_high - yarn_low),\n                            yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,\n          
                  rmsnorm_epsilon=1.0e-5)\n\n        write_tokenizer_header(dst,\n                                num_special_tokens=num_included_tokens - num_text_tokens,\n                                num_text_tokens=num_text_tokens,\n                                regex_size=len(o200k_gptoss._pat_str.encode(\"ascii\")) + 1,\n                                tokens_size=tokens_size)\n\n        ### Tokenizer\n        # Special tokens\n        for token_idx in range(num_text_tokens, num_included_tokens):\n            token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')\n            if token in INCLUDE_SPECIAL_TOKENS:\n                dst.write(SPECIAL_TOKEN_UUID[token])\n            else:\n                dst.write(bytes(16))\n        # Regex\n        dst.write(o200k_gptoss._pat_str.encode(\"ascii\"))\n        dst.write(struct.pack('B', 0))\n        # Text tokens\n        tokenizer_bytes_written = 0\n        for t in range(num_text_tokens):\n            token_bytes = o200k_gptoss.decode_single_token_bytes(t)\n            assert len(token_bytes) > 0\n            dst.write(struct.pack('<H', len(token_bytes)))\n            dst.write(token_bytes)\n            tokenizer_bytes_written += len(token_bytes) + 2\n        assert(tokenizer_bytes_written == tokens_size), (tokenizer_bytes_written, tokens_size)\n        write_padding(dst)\n\n        embedding_weight = get_tensor(\"embedding.weight\")\n        # Filter out unused tokens\n        embedding_weight = embedding_weight[:num_included_tokens, :]\n        write_embedding_weight(dst, embedding_weight)\n\n        for n in tqdm(range(num_blocks)):\n            write_rmsnorm_gain(dst, get_tensor(f\"block.{n}.attn.norm.scale\"))\n\n            attn_qkv_weight = get_tensor(f\"block.{n}.attn.qkv.weight\")\n            attn_qkv_bias = get_tensor(f\"block.{n}.attn.qkv.bias\")\n            for qkv in (attn_qkv_weight, attn_qkv_bias):\n                qk = qkv[:head_dim * (num_q_heads + num_kv_heads), 
...].contiguous()\n                v = qkv[head_dim * (num_q_heads + num_kv_heads):, ...].contiguous()\n                qk = qk.view(num_q_heads + num_kv_heads, 2, head_dim // 2, -1).transpose(1, 2).reshape(num_q_heads + num_kv_heads, head_dim, -1)\n                q = qk[:num_q_heads, ...]\n                k = qk[num_q_heads:, ...]\n                # Factor multiplication by 1/sqrt(64) = 0.125 = 0.5 * 0.25 in SDPA into Q and K projections\n                assert head_dim == 64\n                q *= 0.5\n                k *= 0.25\n                v = v.view(num_kv_heads, head_dim, -1)\n                qkv.copy_(torch.cat((q, k, v), dim=0).reshape(*qkv.shape))\n\n            write_linear_weight(dst, attn_qkv_weight, attn_qkv_bias)\n\n            write_attn_sink(dst, get_tensor(f\"block.{n}.attn.sinks\"))\n\n            write_linear_weight(dst, get_tensor(f\"block.{n}.attn.out.weight\"), get_tensor(f\"block.{n}.attn.out.bias\"))\n\n            write_rmsnorm_gain(dst, get_tensor(f\"block.{n}.mlp.norm.scale\"))\n\n            write_linear_weight(dst, get_tensor(f\"block.{n}.mlp.gate.weight\"), get_tensor(f\"block.{n}.mlp.gate.bias\"))\n\n        write_rmsnorm_gain(dst, get_tensor(\"norm.scale\"))\n\n        unembedding_weight = get_tensor(\"unembedding.weight\")\n        unembedding_weight = unembedding_weight[:num_included_tokens, :]\n        write_linear_weight(dst, unembedding_weight)\n\n        for n in tqdm(range(num_blocks)):\n            mlp1_blocks = get_tensor(f\"block.{n}.mlp.mlp1_weight.blocks\")\n            mlp1_scales = get_tensor(f\"block.{n}.mlp.mlp1_weight.scales\")\n            # Adding UE8_OFFSET to the uint8 scales below must not overflow, so bound the largest scale\n            assert mlp1_scales.max().item() < 254 - UE8_OFFSET\n            mlp1_bias = get_tensor(f\"block.{n}.mlp.mlp1_bias\")\n\n            mlp2_blocks = get_tensor(f\"block.{n}.mlp.mlp2_weight.blocks\")\n            mlp2_scales = get_tensor(f\"block.{n}.mlp.mlp2_weight.scales\")\n            assert mlp2_scales.max().item() < 254 - UE8_OFFSET\n            mlp2_bias = 
get_tensor(f\"block.{n}.mlp.mlp2_bias\")\n\n            # Write MoE weights grouped by expert\n            write_padding(dst)\n\n            for e in range(num_experts):\n                write_padding(dst, alignment_multiple=16)                    \n                dst.write(mlp1_blocks[e, ...].view(torch.uint8).numpy().tobytes())\n\n                write_padding(dst, alignment_multiple=16)\n                dst.write((mlp1_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())\n\n                write_padding(dst, alignment_multiple=16)\n                dst.write(mlp1_bias[e, ...].view(torch.uint8).numpy().tobytes())\n\n                write_padding(dst, alignment_multiple=16)                    \n                dst.write(mlp2_blocks[e, ...].view(torch.uint8).numpy().tobytes())\n\n                write_padding(dst, alignment_multiple=16)\n                dst.write((mlp2_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())\n\n                write_padding(dst, alignment_multiple=16)\n                dst.write(mlp2_bias[e, ...].view(torch.uint8).numpy().tobytes())\n\nif __name__ == \"__main__\":\n    main(sys.argv[1:])\n"
  },
  {
    "path": "gpt_oss/metal/source/accumulate.metal",
    "content": "#include <metal_integer>\n#include <metal_math>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\nkernel void gptoss_f32_accumulate_e4(\n    constant gptoss_accumulate_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device gptoss_expert_prediction* expert [[ buffer(2) ]],\n    device float4* output [[ buffer(3) ]],\n    const device gptoss_control* control [[ buffer(4) ]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_index_in_threadgroup]],\n    uint2 threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const uint num_active_experts = 4;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;\n    const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;\n    const uint num_vecs = args.num_vecs;\n    const uint threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, num_vecs);\n    const uint thread_start = threadgroup_start + tid;\n    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size.x - 1)) / threadgroup_size.x);\n\n    const uint num_vecs_per_expert = args.num_vecs_per_expert;\n    const float scale0 = expert[gid.y * num_active_experts + 0].score;\n    const device float4* input0 = input + gid.y * num_vecs + thread_start;\n    const float scale1 = expert[gid.y * num_active_experts + 1].score;\n    const device float4* input1 = input0 + num_vecs_per_expert;\n    const float scale2 = expert[gid.y * num_active_experts + 2].score;\n    const device float4* input2 = input1 + num_vecs_per_expert;\n    const float scale3 = expert[gid.y * num_active_experts + 3].score;\n    const device float4* input3 = input2 + num_vecs_per_expert;\n    output += gid.y * num_vecs + thread_start;\n    for (; num_iter != 0; num_iter--) {\n        float4 acc = *output;\n        const float4 val0 = *input0;\n  
      const float4 val1 = *input1;\n        const float4 val2 = *input2;\n        const float4 val3 = *input3;\n        input0 += threadgroup_size.x;\n        acc = metal::fma(val0, scale0, acc);\n        input1 += threadgroup_size.x;\n        acc = metal::fma(val1, scale1, acc);\n        input2 += threadgroup_size.x;\n        acc = metal::fma(val2, scale2, acc);\n        input3 += threadgroup_size.x;\n        acc = metal::fma(val3, scale3, acc);\n        *output = acc;\n        output += threadgroup_size.x;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/context.c",
    "content": "#include <assert.h>\n#include <float.h>\n#include <inttypes.h>\n#include <stdbool.h>\n#include <stdint.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <gpt-oss.h>\n\n#include \"internal/datatype.h\"\n#include \"internal/model.h\"\n#include \"internal/metal.h\"\n#include \"internal/metal-kernels.h\"\n#include \"internal/log.h\"\n#include \"internal/rng.h\"\n\n\nenum gptoss_status GPTOSS_ABI gptoss_context_create(\n    gptoss_model_t model,\n    size_t context_length,\n    size_t max_batch_tokens,\n    gptoss_context_t* context_out)\n{\n    *context_out = NULL;\n\n    enum gptoss_status status = gptoss_status_success;\n    struct gptoss_context* context = NULL;\n\n    // Validate context_length\n    if (context_length == 0) {\n        context_length = model->context_length;\n    } else if (context_length > model->context_length) {\n        GPTOSS_LOG_ERROR(\"requested context length %zu exceeds model context length %\" PRIu32,\n            context_length, model->context_length);\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n    assert(context_length != 0);\n    assert(context_length <= model->context_length);\n\n    // Validate max_batch_tokens\n    if (max_batch_tokens == 0) {\n        max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;\n    } else if (max_batch_tokens > context_length) {\n        GPTOSS_LOG_ERROR(\"requested max batch tokens %zu exceeds context length %zu\",\n            max_batch_tokens, context_length);\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n    assert(max_batch_tokens != 0);\n    assert(max_batch_tokens <= context_length);\n\n    context = malloc(sizeof(struct gptoss_context));\n    if (context == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to allocate %zu bytes for Context object\",\n            sizeof(struct gptoss_context));\n        status = gptoss_status_insufficient_memory;\n        goto cleanup;\n    }\n    memset(context, 0, sizeof(struct 
gptoss_context));\n\n    atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);\n    context->max_tokens = context_length;\n    context->max_batch_tokens = max_batch_tokens;\n\n    // Activation buffers\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    // The last entry will hold the total number of tokens.\n    status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = 
gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n\n    // Input/output buffers\n    status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, 
&context->sum_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n\n    context->kvcache_size = context->kvcache_buffer.size;\n    context->allocation_size = \n        context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +\n        context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +\n        context->gate_activation_buffer.size + context->expert_activation_buffer.size +\n        context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +\n        context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +\n        context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;\n\n    context->model = model;\n    gptoss_model_retain(model);\n    *context_out = context;\n    context = NULL;\n\ncleanup:\n    gptoss_context_release(context);\n    return status;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(\n    gptoss_context_t context,\n    size_t* num_tokens_out)\n{\n    *num_tokens_out = context->num_tokens;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(\n    gptoss_context_t context,\n    size_t* max_tokens_out)\n{\n    *max_tokens_out = context->max_tokens;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(\n    gptoss_context_t context,\n    uint32_t* tokens_out,\n    size_t 
max_tokens,\n    size_t* num_tokens_out)\n{\n    *num_tokens_out = context->num_tokens;\n    if (max_tokens < context->num_tokens) {\n        return gptoss_status_insufficient_memory;\n    }\n\n    if (context->num_tokens != 0) {\n        memcpy(tokens_out, context->token_buffer.ptr, context->num_tokens * sizeof(uint32_t));\n    }\n    return gptoss_status_success;\n}\n\n// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.\n// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.\n// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.\nstatic enum gptoss_status process_tokens(\n    gptoss_context_t context,\n    struct gptoss_metal_command_buffer* command_buffer,\n    size_t input_tokens_offset,\n    size_t num_input_tokens,\n    size_t num_output_tokens)\n{\n    assert(num_input_tokens != 0);\n    assert(num_input_tokens <= context->max_batch_tokens);\n    assert(num_output_tokens <= context->max_batch_tokens);\n    assert(num_input_tokens >= num_output_tokens);\n    const size_t min_tokens_for_dense_matmul_kernels = 64;\n    const size_t min_tokens_for_dense_moe_kernels = 64;\n\n    enum gptoss_status status = gptoss_status_success;\n    const struct gptoss_model* model = context->model;\n\n    const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);\n\n    const size_t input_tokens_end = input_tokens_offset + num_input_tokens;\n    for (size_t input_batch_start = input_tokens_offset;\n        input_batch_start < input_tokens_end;\n        input_batch_start += context->max_batch_tokens)\n    {\n        const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);\n        const size_t input_batch_end = input_batch_start + input_batch_size;\n        const size_t output_batch_size = math_sub_sat(num_output_tokens, 
input_tokens_end - input_batch_end);\n\n        status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(\n            command_buffer,\n            &model->bf16_f32_embeddings_fn,\n            model->embeddings_threadgroup_size,\n            &context->token_buffer,\n            input_batch_start * sizeof(uint32_t),\n            &model->shared_weight_buffer,\n            /*weight_offset=*/0,\n            &context->residual_activation_buffer,\n            /*output_offset=*/0,\n            &context->control_buffer,\n            /*control_offset=*/0,\n            /*num_tokens=*/input_batch_size,\n            /*num_channels=*/model->embedding_dim);\n        if (status != gptoss_status_success) {\n            GPTOSS_LOG_ERROR(\"failed to encode bf16_f32_embeddings kernel launch\");\n            return status;\n        }\n        for (uint32_t n = 0; n < model->num_blocks; n++) {\n            const bool last_block = n + 1 == model->num_blocks;\n            const size_t num_block_output_tokens = last_block ? 
output_batch_size : input_batch_size;\n\n            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n                command_buffer,\n                &model->f32_bf16w_rmsnorm_fn,\n                &context->residual_activation_buffer,\n                /*input_offset=*/0,\n                &model->shared_weight_buffer,\n                /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,\n                &context->rmsnorm_activation_buffer,\n                /*output_offset=*/0,\n                &context->control_buffer,\n                /*control_offset=*/0,\n                /*num_tokens=*/input_batch_size,\n                /*num_channels=*/model->embedding_dim,\n                model->rmsnorm_epsilon);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_rmsnorm kernel launch\");\n                return status;\n            }\n\n            if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {\n                status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(\n                    command_buffer,\n                    &model->f32_bf16w_dense_matmul_qkv_fn,\n                    &context->rmsnorm_activation_buffer,\n                    /*input_offset=*/0,\n                    &model->shared_weight_buffer,\n                    /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,\n                    &model->shared_weight_buffer,\n                    /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,\n                    &context->qkv_activation_buffer,\n                    /*output_offset=*/0,\n                    &context->kvcache_buffer,\n                    /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),\n                    &context->control_buffer,\n                    
/*control_offset=*/0,\n                    /*num_tokens=*/input_batch_size,\n                    /*num_cols=*/model->embedding_dim,\n                    /*num_rows=*/attn_qkv_dim,\n                    /*max_tokens=*/context->max_tokens,\n                    /*token_offset=*/input_batch_start);\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul_qkv kernel launch\");\n                    return status;\n                }\n\n                status = gptoss_metal_command_buffer_encode_launch_f32_rope(\n                    command_buffer,\n                    &model->f32_rope_fn,\n                    /*threadgroup_size=*/32,\n                    &context->qkv_activation_buffer,\n                    /*input_offset=*/0,\n\n                    &context->kvcache_buffer,\n                    /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),\n                    &context->control_buffer,\n                    /*control_offset=*/0,\n                    model->rope_theta,\n                    model->interpolation_scale,\n                    model->yarn_offset,\n                    model->yarn_scale,\n                    model->yarn_multiplier,\n                    input_batch_size,\n                    model->num_heads,\n                    model->num_kv_heads,\n                    model->head_dim,\n                    /*max_tokens=*/context->max_tokens,\n                    /*token_offset=*/input_batch_start);\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode f32_rope kernel launch\");\n                    return status;\n                }\n\n            } else {\n                status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(\n                    command_buffer,\n                    &model->f32_bf16w_matmul_qkv_fn,\n                    
model->attn_qkv_threadgroup_size,\n                    &context->rmsnorm_activation_buffer,\n                    /*input_offset=*/0,\n                    &model->shared_weight_buffer,\n                    /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,\n                    &model->shared_weight_buffer,\n                    /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,\n                    &context->qkv_activation_buffer,\n                    /*output_offset=*/0,\n                    &context->kvcache_buffer,\n                    /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),\n                    &context->control_buffer,\n                    /*control_offset=*/0,\n                    /*num_tokens=*/input_batch_size,\n                    /*num_cols=*/model->embedding_dim,\n                    /*num_q_heads=*/model->num_heads,\n                    /*num_kv_heads=*/model->num_kv_heads,\n                    /*attn_head_dim=*/model->head_dim,\n                    /*token_offset=*/input_batch_start,\n                    /*max_tokens=*/context->max_tokens,\n                    /*rope_base=*/model->rope_theta,\n                    /*interpolation_scale=*/model->interpolation_scale,\n                    /*yarn_offset=*/model->yarn_offset,\n                    /*yarn_scale=*/model->yarn_scale,\n                    /*yarn_multiplier=*/model->yarn_multiplier);\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch\");\n                    return status;\n                }\n            }\n\n            if (num_block_output_tokens != 0) {\n                status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(\n                    command_buffer,\n                    &model->f32_sdpa_q8_d64_fn,\n                    
&context->qkv_activation_buffer,\n                    /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                    &context->kvcache_buffer,\n                    /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),\n                    &model->shared_weight_buffer,\n                    /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,\n                    &context->sdpa_activation_buffer,\n                    /*output_offset=*/0,\n                    &context->control_buffer,\n                    /*control_offset=*/0,\n                    /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,\n                    /*kv_stride=*/2 * context->max_tokens * model->head_dim,\n                    num_block_output_tokens,\n                    input_batch_start + input_batch_size - num_block_output_tokens,\n                    model->num_heads, model->num_kv_heads, model->head_dim);\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode f32_sdpa kernel launch\");\n                    return status;\n                }\n\n                if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {\n                    status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(\n                        command_buffer,\n                        &model->f32_bf16w_dense_matmul_attn_output_fn,\n                        &context->sdpa_activation_buffer,\n                        /*input_offset=*/0,\n                        &model->shared_weight_buffer,\n                        /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,\n                        &model->shared_weight_buffer,\n                        /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,\n                        
&context->residual_activation_buffer,\n                        /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        /*num_tokens=*/num_block_output_tokens,\n                        /*num_cols=*/model->num_heads * model->head_dim,\n                        /*num_rows=*/model->embedding_dim);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul_attn_output kernel launch\");\n                        return status;\n                    }\n                } else {\n                    status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(\n                        command_buffer,\n                        &model->f32_bf16w_matmul_fn,\n                        model->attn_out_threadgroup_size,\n                        &context->sdpa_activation_buffer,\n                        /*input_offset=*/0,\n                        &model->shared_weight_buffer,\n                        /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,\n                        &model->shared_weight_buffer,\n                        /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,\n                        &context->residual_activation_buffer,\n                        /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        /*num_tokens=*/num_block_output_tokens,\n                        /*num_cols=*/model->num_heads * model->head_dim,\n                        /*num_rows=*/model->embedding_dim);\n                    if (status != gptoss_status_success) {\n                        
GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_add kernel launch\");\n                        return status;\n                    }\n                }\n                status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n                    command_buffer,\n                    &model->f32_bf16w_rmsnorm_fn,\n                    &context->residual_activation_buffer,\n                    /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                    &model->shared_weight_buffer,\n                    /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,\n                    &context->rmsnorm_activation_buffer,\n                    /*output_offset=*/0,\n                    &context->control_buffer,\n                    /*control_offset=*/0,\n                    num_block_output_tokens,\n                    model->embedding_dim,\n                    model->rmsnorm_epsilon);\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_rmsnorm kernel launch\");\n                    return status;\n                }\n                if (input_batch_size >= min_tokens_for_dense_matmul_kernels) {\n                    status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(\n                        command_buffer,\n                        &model->f32_bf16w_dense_matmul_mlp_gate_fn,\n                        &context->rmsnorm_activation_buffer,\n                        /*input_offset=*/0,\n                        &model->shared_weight_buffer,\n                        /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,\n                        &model->shared_weight_buffer,\n                        /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,\n                        
&context->gate_activation_buffer,\n                        /*output_offset=*/0,\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        num_block_output_tokens,\n                        model->embedding_dim,\n                        model->num_experts);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch\");\n                        return status;\n                    }\n                } else {\n                    status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(\n                        command_buffer,\n                        &model->f32_bf16w_matmul_fn,\n                        model->mlp_gate_threadgroup_size,\n                        &context->rmsnorm_activation_buffer,\n                        /*input_offset=*/0,\n                        &model->shared_weight_buffer,\n                        /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,\n                        &model->shared_weight_buffer,\n                        /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,\n                        &context->gate_activation_buffer,\n                        /*output_offset=*/0,\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        /*num_tokens=*/num_block_output_tokens,\n                        /*num_cols=*/model->embedding_dim,\n                        /*num_rows=*/model->num_experts);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul kernel launch\");\n                        return status;\n                    }\n                }\n\n                const char* kernel_name = NULL;\n                switch (model->num_experts) {\n    
                case 32:\n                        kernel_name = \"f32_topk_softmax_e32_k4_fn\";\n                        status = gptoss_metal_command_buffer_encode_launch_f32_topk(\n                            command_buffer,\n                            &model->f32_topk_softmax_e32_k4_fn,\n                            &context->gate_activation_buffer, /*input_offset=*/0,\n                            &context->expert_activation_buffer, /*output_offset=*/0,\n                            &context->control_buffer, /*control_offset=*/0,\n                            num_block_output_tokens,\n                            model->num_experts,\n                            model->num_active_experts);\n                        break;\n                    case 128:\n                        kernel_name = \"f32_topk_softmax_e128_k4_fn\";\n                        status = gptoss_metal_command_buffer_encode_launch_f32_topk(\n                            command_buffer,\n                            &model->f32_topk_softmax_e128_k4_fn,\n                            &context->gate_activation_buffer, /*input_offset=*/0,\n                            &context->expert_activation_buffer, /*output_offset=*/0,\n                            &context->control_buffer, /*control_offset=*/0,\n                            num_block_output_tokens,\n                            model->num_experts,\n                            model->num_active_experts);\n                        break;\n                    default:\n                        status = gptoss_status_unsupported_argument;\n                        GPTOSS_LOG_ERROR(\"missing Top-K kernel for %\" PRIu32 \" experts\", model->num_experts);\n                        return status;\n                }\n                if (status != gptoss_status_success) {\n                    GPTOSS_LOG_ERROR(\"failed to encode %s kernel launch\", kernel_name);\n                    return status;\n                }\n\n                // If we have enough tokens in 
prefill, we will pick the prefill-optimized kernels.\n                if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {\n                    status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(\n                        command_buffer,\n                        &model->f32_expert_routing_metadata_fn,\n                        &context->expert_activation_buffer,\n                        /*expert_predictions_offset=*/0,\n                        &context->expert_offset_buffer,\n                        /*expert_offsets_offset=*/0,\n                        &context->token_to_expert_routing_buffer,\n                        /*intra_expert_offsets_offset=*/0,\n                        num_block_output_tokens * model->num_active_experts,\n                        model->num_experts);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_expert_routing_metadata kernel launch\");\n                        return status;\n                    }\n                    status = gptoss_metal_command_buffer_encode_launch_f32_scatter(\n                        command_buffer,\n                        &model->f32_scatter_e4_fn,\n                        &context->rmsnorm_activation_buffer,\n                        /*input_offset=*/0,\n                        &context->expert_activation_buffer,\n                        /*expert_predictions_offset=*/0,\n                        &context->expert_offset_buffer,\n                        /*expert_offsets_offset=*/0,\n                        &context->token_to_expert_routing_buffer,\n                        /*intra_expert_offsets_offset=*/0,\n                        &context->swiglu_input_buffer,\n                        /*output_offset=*/0,\n                        model->embedding_dim,\n                        num_block_output_tokens,\n                        model->num_active_experts);\n                    if (status != 
gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_scatter kernel launch\");\n                        return status;\n                    } \n                    // Dense MoE SwiGLU matmul.\n                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(\n                        command_buffer,\n                        &model->f32_mf4w_moe_dense_matmul_swiglu_fn,\n                        &context->expert_offset_buffer,\n                        /*expert_offsets_offset=*/0,\n                        &context->swiglu_input_buffer,\n                        /*input_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_block_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,\n                        &model->block_weight_buffers[n],\n                        /*bias_offset=*/model->mlp_swiglu_bias_offset,\n                        &context->swiglu_activation_buffer,\n                        /*output_offset=*/0,\n                        model->swiglu_limit,\n                        /*expert_stride_bytes=*/model->per_expert_block_weight_size,\n                        num_block_output_tokens,\n                        model->num_experts,\n                        model->embedding_dim,\n                        2 * model->mlp_dim);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch\");\n                        return status;\n                    }\n\n                    // Dense MoE proj matmul.\n                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(\n                        command_buffer,\n                        &model->f32_mf4w_moe_dense_matmul_fn,\n                        
&context->expert_offset_buffer,\n                        /*expert_offsets_offset=*/0,\n                        &context->swiglu_activation_buffer,\n                        /*input_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_block_offset=*/model->mlp_out_block_offset,\n                        &model->block_weight_buffers[n],\n                        /*weight_scale_offset=*/model->mlp_out_scale_offset,\n                        &model->block_weight_buffers[n],\n                        /*bias_offset=*/model->mlp_out_bias_offset,\n                        &context->moe_activation_buffer,\n                        /*output_offset=*/0,\n                        /*expert_stride_bytes=*/model->per_expert_block_weight_size,\n                        num_block_output_tokens,\n                        model->num_experts,\n                        model->mlp_dim,\n                        model->embedding_dim);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch\");\n                        return status;\n                    }\n                    // Gather and accumulate.\n                    status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(\n                        command_buffer,\n                        &model->f32_gather_and_accumulate_e4_fn,\n                        &context->moe_activation_buffer,\n                        /*input_offset=*/0,\n                        &context->expert_activation_buffer,\n                        /*expert_predictions_offset=*/0,\n                        &context->expert_offset_buffer,\n                        /*expert_offsets_offset=*/0,\n                        &context->token_to_expert_routing_buffer,\n                        /*intra_expert_offsets_offset=*/0,\n                        &context->residual_activation_buffer,\n                   
     /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                        model->embedding_dim,\n                        num_block_output_tokens,\n                        model->num_active_experts);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_gather_and_accumulate_e4 kernel launch\");\n                        return status;\n                    }\n\n                } else {\n                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(\n                        command_buffer,\n                        &model->f32_mf4w_moe_matmul_swiglu_fn,\n                        model->mlp_swiglu_threadgroup_size,\n                        &context->rmsnorm_activation_buffer,\n                        /*input_offset=*/0,\n                        &context->expert_activation_buffer,\n                        /*expert_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_block_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,\n                        &model->block_weight_buffers[n],\n                        /*bias_offset=*/model->mlp_swiglu_bias_offset,\n                        &context->swiglu_activation_buffer,\n                        /*output_offset=*/0,\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        model->swiglu_limit,\n                        model->per_expert_block_weight_size,\n                        num_block_output_tokens,\n                        model->num_active_experts,\n                        model->embedding_dim,\n                        model->mlp_dim);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to 
encode f32_mf4w_moe_matmul_swiglu kernel launch\");\n                        return status;\n                    }\n\n                    status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(\n                        command_buffer,\n                        &model->f32_mf4w_moe_matmul_fn,\n                        model->mlp_out_threadgroup_size,\n                        &context->swiglu_activation_buffer,\n                        /*input_offset=*/0,\n                        &context->expert_activation_buffer,\n                        /*expert_offset=*/0,\n                        &model->block_weight_buffers[n],\n                        /*weight_block_offset=*/model->mlp_out_block_offset,\n                        &model->block_weight_buffers[n],\n                        /*weight_scale_offset=*/model->mlp_out_scale_offset,\n                        &model->block_weight_buffers[n],\n                        /*bias_offset=*/model->mlp_out_bias_offset,\n                        &context->moe_activation_buffer,\n                        /*output_offset=*/0,\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        model->per_expert_block_weight_size,\n                        num_block_output_tokens,\n                        model->num_active_experts,\n                        model->mlp_dim,\n                        model->embedding_dim);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch\");\n                        return status;\n                    }\n\n                    status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(\n                        command_buffer,\n                        &model->f32_accumulate_e4_fn,\n                        model->mlp_acc_threadgroup_size,\n                        model->max_threadgroups,\n                        
&context->moe_activation_buffer,\n                        /*input_offset=*/0,\n                        &context->expert_activation_buffer,\n                        /*expert_offset=*/0,\n                        &context->residual_activation_buffer,\n                        /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),\n                        &context->control_buffer,\n                        /*control_offset=*/0,\n                        model->embedding_dim,\n                        num_block_output_tokens,\n                        model->num_active_experts);\n                    if (status != gptoss_status_success) {\n                        GPTOSS_LOG_ERROR(\"failed to encode f32_accumulate kernel launch\");\n                        return status;\n                    }\n                }\n            }\n        }\n\n        if (output_batch_size != 0) {\n            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n                command_buffer,\n                &model->f32_bf16w_rmsnorm_fn,\n                &context->residual_activation_buffer,\n                /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),\n                &model->shared_weight_buffer,\n                /*weight_offset=*/model->rmsnorm_weight_offset,\n                &context->rmsnorm_activation_buffer,\n                /*output_offset=*/0,\n                &context->control_buffer,\n                /*control_offset=*/0,\n                /*num_tokens=*/output_batch_size,\n                /*num_channels=*/model->embedding_dim,\n                model->rmsnorm_epsilon);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_rmsnorm kernel launch\");\n                return status;\n            }\n\n            status = gptoss_metal_command_buffer_encode_fill_buffer(\n                command_buffer,\n             
   &context->argmax_buffer,\n                /*offset=*/0,\n                /*size=*/sizeof(uint64_t) * output_batch_size,\n                /*fill_value=*/0xFF);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode fill buffer command\");\n                return status;\n            }\n\n            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(\n                command_buffer,\n                &model->f32_bf16w_unembedding_fn,\n                model->unembedding_threadgroup_size,\n                model->max_threadgroups,\n                &context->rmsnorm_activation_buffer,\n                /*input_offset=*/0,\n                &model->shared_weight_buffer,\n                /*weight_offset=*/model->unembedding_weight_offset,\n                &context->score_buffer,\n                /*output_offset=*/0,\n                &context->argmax_buffer,\n                /*argmax_offset=*/0,\n                &context->control_buffer,\n                /*control_offset=*/0,\n                /*num_tokens=*/output_batch_size,\n                /*num_cols=*/model->embedding_dim,\n                /*num_rows=*/model->vocabulary_size);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_unembedding kernel launch\");\n                return status;\n            }\n        }\n    }\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_append_chars(\n    gptoss_context_t context,\n    const char* text,\n    size_t text_length,\n    size_t* num_tokens_out)\n{\n    enum gptoss_status status = gptoss_status_success;\n    const struct gptoss_model* model = context->model;\n    const struct gptoss_tokenizer* tokenizer = model->tokenizer;\n    size_t num_appended_tokens = 0;\n    while (text_length != 0) {\n        if (context->num_tokens == context->max_tokens) {\n            status = gptoss_status_context_overflow;\n     
       break;\n        }\n        const char* tokens = tokenizer->tokens_ptr;\n        uint32_t best_token = UINT32_MAX;\n        uint32_t best_token_length = 0;\n        for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {\n            uint16_t token_length;\n            memcpy(&token_length, tokens, sizeof(uint16_t));\n            tokens += sizeof(uint16_t);\n            // Greedy longest-prefix match: keep the longest text token that matches the start of the remaining input.\n            if (token_length <= text_length && token_length > best_token_length &&\n                memcmp(text, tokens, token_length) == 0) {\n                best_token = (uint32_t) t;\n                best_token_length = token_length;\n            }\n            tokens += token_length;\n        }\n\n        if (best_token == UINT32_MAX) {\n            GPTOSS_LOG_ERROR(\"failed to tokenize text \\\"%.*s\\\"\", (int) text_length, text);\n            return gptoss_status_invalid_argument;\n        }\n\n        uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;\n        if (context->num_kv_tokens > context->num_tokens) {\n            if (input_tokens[context->num_tokens] != best_token) {\n                input_tokens[context->num_tokens] = best_token;\n\n                // Invalidate the KV cache starting with the newly added token.\n                context->num_kv_tokens = context->num_tokens;\n            }\n            context->num_tokens++;\n        } else {\n            input_tokens[context->num_tokens++] = best_token;\n        }\n        num_appended_tokens++;\n        text += best_token_length;\n        text_length -= best_token_length;\n    }\n    if (num_tokens_out != NULL) {\n        *num_tokens_out = num_appended_tokens;\n    }\n    return status;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(\n    gptoss_context_t context,\n    size_t num_tokens,\n    const uint32_t* tokens)\n{\n    const struct gptoss_model* model = context->model;\n\n    
// Validate all tokens\n    for (size_t t = 0; t < num_tokens; t++) {\n        const uint32_t token = tokens[t];\n        if (token >= model->vocabulary_size) {\n            GPTOSS_LOG_ERROR(\"token %\" PRIu32 \" at index %zu is out of bounds for vocabulary size %\" PRIu32,\n                token, t, context->model->vocabulary_size);\n            return gptoss_status_invalid_argument;\n        }\n    }\n\n    enum gptoss_status status = gptoss_status_success;\n    uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;\n    while (num_tokens != 0) {\n        if (context->num_tokens == context->max_tokens) {\n            status = gptoss_status_context_overflow;\n            break;\n        }\n\n        if (context->num_kv_tokens > context->num_tokens) {\n            const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);\n            size_t num_verified_tokens = 0;\n            for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {\n                if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {\n                    // Invalidate the KV cache starting with the newly added tokens.\n                    context->num_kv_tokens = context->num_tokens + num_verified_tokens;\n                    break;\n                }\n            }\n\n            context->num_tokens += num_verified_tokens;\n            tokens += num_verified_tokens;\n            num_tokens -= num_verified_tokens;\n        } else {\n            const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);\n            memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));\n            context->num_tokens += num_tokens_to_copy;\n            tokens += num_tokens_to_copy;\n            num_tokens -= num_tokens_to_copy;\n        }\n    }\n\n    return status;\n}\n\nenum gptoss_status GPTOSS_ABI 
gptoss_context_process(\n    gptoss_context_t context)\n{\n    if (context->num_tokens > context->num_kv_tokens) {\n        struct gptoss_metal_command_buffer command_buffer = {0};\n\n        enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n\n        struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;\n        control->abort = 0;\n\n        status = process_tokens(\n            context,\n            &command_buffer,\n            /*input_tokens_offset=*/context->num_kv_tokens,\n            /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,\n            /*num_output_tokens=*/0);\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n\n        status = gptoss_metal_command_buffer_commit(&command_buffer);\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n\n        status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n\n        context->num_kv_tokens = context->num_tokens;\n\ncleanup:\n        gptoss_metal_command_buffer_release(&command_buffer);\n        return status;\n    }\n    \n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_sample(\n    gptoss_context_t context,\n    float temperature,\n    uint64_t seed,\n    size_t max_tokens,\n    uint32_t* tokens_out,\n    size_t* num_tokens_out)\n{\n    enum gptoss_status status = gptoss_status_success;\n    const struct gptoss_model* model = context->model;\n    struct gptoss_metal_command_buffer command_buffer = {0};\n\n    *num_tokens_out = 0;\n\n    const uint32_t num_original_tokens = context->num_tokens;\n\n    status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);\n    
if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n\n    struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;\n    control->abort = 0;\n\n    for (size_t t = 0; t < max_tokens; t++) {\n        if (context->num_kv_tokens < context->num_tokens) {\n            status = process_tokens(\n                context,\n                &command_buffer,\n                /*input_tokens_offset=*/context->num_kv_tokens,\n                /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,\n                /*num_output_tokens=*/1);\n            context->num_kv_tokens = context->num_tokens;\n        } else {\n            status = process_tokens(\n                context,\n                &command_buffer,\n                /*input_tokens_offset=*/context->num_tokens - 1,\n                /*num_input_tokens=*/1,\n                /*num_output_tokens=*/1);\n        }\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n\n        if (temperature != 0.0f) {\n            assert(context->num_processed_tokens != 0);\n            uint32_t num_threadgroups = 0;\n            uint32_t num_dims_per_threadgroup = 0;\n            status = gptoss_metal_command_buffer_encode_launch_f32_softmax(\n                &command_buffer,\n                &model->f32_softmax_fn,\n                /*threadgroup_size=*/512,\n                model->max_threadgroups,\n                &context->score_buffer,\n                /*score_offset=*/0,\n                &context->argmax_buffer,\n                /*argmax_offset=*/0,\n                &context->prob_buffer,\n                /*prob_offset=*/0,\n                &context->sum_buffer,\n                /*sum_offset=*/0,\n                &context->control_buffer,\n                /*control_offset=*/0,\n                model->vocabulary_size,\n                /*num_tokens=*/1,\n                temperature,\n                &num_threadgroups,\n                
&num_dims_per_threadgroup);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode f32_softmax kernel launch\");\n                goto cleanup;\n            }\n\n            status = gptoss_metal_command_buffer_encode_launch_f32_sample(\n                &command_buffer,\n                &model->f32_sample_fn,\n                /*min_threadgroup_size=*/512,\n                &context->prob_buffer,\n                /*prob_offset=*/0,\n                &context->sum_buffer,\n                /*sum_offset=*/0,\n                &context->token_buffer,\n                /*token_offset=*/context->num_tokens * sizeof(uint32_t),\n                &context->control_buffer,\n                /*control_offset=*/0,\n                /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),\n                /*rng_offset=*/context->num_tokens,\n                /*num_blocks=*/num_threadgroups,\n                /*num_channels=*/model->vocabulary_size,\n                /*num_channels_per_block=*/num_dims_per_threadgroup);\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode f32_sample kernel launch\");\n                goto cleanup;\n            }\n        } else {\n            status = gptoss_metal_command_buffer_encode_copy_buffer(\n                &command_buffer,\n                &context->argmax_buffer,\n                /*input_offset=*/0,\n                &context->token_buffer,\n                /*output_offset=*/context->num_tokens * sizeof(uint32_t),\n                /*size=*/sizeof(uint32_t));\n            if (status != gptoss_status_success) {\n                GPTOSS_LOG_ERROR(\"failed to encode copy buffer\");\n                goto cleanup;\n            }\n        }\n        context->num_tokens += 1;\n        context->num_kv_tokens = context->num_tokens;\n    }\n\n    gptoss_metal_command_buffer_commit(&command_buffer);\n    gptoss_metal_command_buffer_wait_completion(&command_buffer, 
NULL);\n\n    const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;\n    const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;\n    memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));\n    *num_tokens_out = num_generated_tokens;\n\ncleanup:\n    gptoss_metal_command_buffer_release(&command_buffer);\n    return status;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_reset(\n    gptoss_context_t context)\n{\n    context->num_tokens = 0;\n\n    // Note: context->num_kv_tokens is not reset and context->token_buffer is not cleared.\n    // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache.\n\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_retain(\n    gptoss_context_t context)\n{\n    atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_context_release(\n    gptoss_context_t context)\n{\n    if (context != NULL) {\n        if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {\n            // Activation buffers\n            gptoss_metal_buffer_release(&context->residual_activation_buffer);\n            gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);\n            gptoss_metal_buffer_release(&context->qkv_activation_buffer);\n            gptoss_metal_buffer_release(&context->sdpa_activation_buffer);\n            gptoss_metal_buffer_release(&context->gate_activation_buffer);\n            gptoss_metal_buffer_release(&context->expert_activation_buffer);\n            gptoss_metal_buffer_release(&context->swiglu_activation_buffer);\n            gptoss_metal_buffer_release(&context->moe_activation_buffer);\n            gptoss_metal_buffer_release(&context->expert_offset_buffer);\n            
gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);\n            gptoss_metal_buffer_release(&context->swiglu_input_buffer);\n\n            // Input/output buffers\n            gptoss_metal_buffer_release(&context->control_buffer);\n            gptoss_metal_buffer_release(&context->token_buffer);\n            gptoss_metal_buffer_release(&context->score_buffer);\n            gptoss_metal_buffer_release(&context->prob_buffer);\n            gptoss_metal_buffer_release(&context->sum_buffer);\n            gptoss_metal_buffer_release(&context->argmax_buffer);\n            gptoss_metal_buffer_release(&context->kvcache_buffer);\n\n            gptoss_model_release(context->model);\n\n            memset(context, 0, sizeof(struct gptoss_context));\n            free(context);\n        }\n    }\n    return gptoss_status_success;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/convert.metal",
    "content": "#include <metal_integer>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\nkernel void gptoss_mf4_f32_convert(\n    constant gptoss_convert_args& args [[ buffer(0) ]],\n    const device uint4* blocks [[ buffer(1) ]],\n    const device uchar* scales [[ buffer(2) ]],\n    device float4* output [[ buffer(3) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;\n    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;\n    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);\n    const ulong thread_start = threadgroup_start + tid;\n    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);\n\n    blocks += thread_start;\n    scales += thread_start;\n    output += 8 * thread_start;\n    for (; num_iter != 0; num_iter--) {\n        const uint4 block = *blocks;\n        const float scale = as_type<float>((static_cast<uint>(*scales) + 14) << 23);\n        uint4 block02468ACEGIKMOQSU = block + block;\n        uint4 block13579BDFHJLNPRTV = block >> 3;\n        block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;\n        block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;\n        block02468ACEGIKMOQSU += 0x70707070u;\n        block13579BDFHJLNPRTV += 0x70707070u;\n        block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;\n        block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;\n        const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;\n        const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;\n        const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;\n        const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;\n        const float4 block048C = 
static_cast<float4>(as_type<half4>(block048CGKOS.xy)) * scale;\n        const float4 blockGKOS = static_cast<float4>(as_type<half4>(block048CGKOS.zw)) * scale;\n        const float4 block26AE = static_cast<float4>(as_type<half4>(block26AEIMQU.xy)) * scale;\n        const float4 blockIMQU = static_cast<float4>(as_type<half4>(block26AEIMQU.zw)) * scale;\n        const float4 block159D = static_cast<float4>(as_type<half4>(block159DHLPT.xy)) * scale;\n        const float4 blockHLPT = static_cast<float4>(as_type<half4>(block159DHLPT.zw)) * scale;\n        const float4 block37BF = static_cast<float4>(as_type<half4>(block37BFJNRV.xy)) * scale;\n        const float4 blockJNRV = static_cast<float4>(as_type<half4>(block37BFJNRV.zw)) * scale;\n\n        output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };\n        output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };\n        output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };\n        output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };\n        output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };\n        output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };\n        output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };\n        output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };\n\n        blocks += threadgroup_size;\n        scales += threadgroup_size;\n        output += 8 * threadgroup_size;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/embeddings.metal",
    "content": "#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\nkernel void gptoss_bf16_f32_embeddings(\n    constant gptoss_embeddings_args& args [[ buffer(0) ]],\n    const device uint* tokens [[ buffer(1) ]],\n    const device bfloat4* weights [[ buffer(2) ]],\n    device float4* output [[ buffer(3) ]],\n    const device gptoss_control* control [[ buffer(4) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint t = tokens[gid];\n\n    weights += t * args.num_vecs;\n    output += gid * args.num_vecs;\n    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {\n        const bfloat4 w = weights[i];\n        output[i] = static_cast<float4>(w);\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/expert_routing_metadata.metal",
    "content": "#include <internal/kernel-args.h>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_stdlib>\n\nconstant uint kMaxExperts = 128;\n\nkernel void gptoss_f32_expert_routing_metadata(\n    constant gptoss_expert_routing_metadata_args& args [[ buffer(0) ]],\n    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(1) ]],\n    device uint* __restrict__ expert_offsets [[ buffer(2) ]],\n    device uint* __restrict__ intra_expert_offsets [[ buffer(3) ]],\n    uint tg_size [[threads_per_threadgroup]],\n    uint tid [[thread_position_in_threadgroup]]) \n{\n    assert(args.num_experts <= kMaxExperts);\n    // Create threadgroup mem and initialize it to 0.\n    threadgroup metal::atomic_uint tg_counts[kMaxExperts];\n    for (uint e = tid; e < args.num_experts; e += tg_size) {\n        metal::atomic_store_explicit(&tg_counts[e], 0u, metal::memory_order_relaxed);\n    }\n\n    threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n\n    for (uint i = tid; i < args.tokens; i += tg_size) {\n        const uint e = expert_predictions[i].expert_id;\n        const uint r = metal::atomic_fetch_add_explicit(&tg_counts[e], 1u, metal::memory_order_relaxed);\n        intra_expert_offsets[i] = r;\n    }\n    threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n\n    if (tid == 0) {\n        uint total = 0;\n        for (uint e = 0; e < args.num_experts; ++e) {\n            const uint bin = metal::atomic_load_explicit(&tg_counts[e], metal::memory_order_relaxed);\n            expert_offsets[e] = total;\n            total += bin;\n        }\n        expert_offsets[args.num_experts] = total;\n    }\n}"
  },
  {
    "path": "gpt_oss/metal/source/gather_and_accumulate.metal",
    "content": "#include <internal/kernel-args.h>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_stdlib>\n\n// TODO(ibrahim): This is not optimal as each thread only gathers and accumulates a single float4. To amortize the\n// cost of reading the expert, offset and scales for a token, we should let each thread gather and accumulate several\n// float4s.\nkernel void gptoss_f32_gather_and_accumulate_e4(\n    constant gptoss_gather_args& args [[ buffer(0) ]],\n    const device float* in [[ buffer(1) ]],\n    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],\n    const device uint* expert_offsets [[ buffer(3) ]],\n    const device uint* intra_expert_offsets [[ buffer(4) ]],\n    device float* out [[ buffer(5) ]],\n    uint3 gid [[thread_position_in_grid]]) \n{\n    const uint T = args.tokens;\n    const uint k = args.active_experts_per_token;\n    const uint D = args.token_stride;\n\n    assert((D & 3u) == 0);\n    assert(k == 4);\n\n    const uint row = gid.y;\n    if (row >= T) {\n        return;\n    }\n\n    const uint col_vec4 = gid.x;\n    const uint col = col_vec4 * 4u;\n    if (col >= D) {\n        return;\n    }\n\n    device float4* dst4 = reinterpret_cast<device float4*>(out + row * D + col);\n\n    const uint base = row * k;\n    const gptoss_expert_prediction expert0 = expert_predictions[base];\n    const gptoss_expert_prediction expert1 = expert_predictions[base + 1];\n    const gptoss_expert_prediction expert2 = expert_predictions[base + 2];\n    const gptoss_expert_prediction expert3 = expert_predictions[base + 3];\n    const uint expert0_id = expert0.expert_id;\n    const uint expert1_id = expert1.expert_id;\n    const uint expert2_id = expert2.expert_id;\n    const uint expert3_id = expert3.expert_id;\n    const float scale0 = expert0.score;\n    const float scale1 = expert1.score;\n    const float scale2 = expert2.score;\n    const float scale3 = expert3.score;\n    const uint4 
current_intra_expert_offsets =\n        *reinterpret_cast<const device uint4*>(&intra_expert_offsets[base]);\n    // Get the row indices for the current expert ids\n    const uint r0 = expert_offsets[expert0_id] + current_intra_expert_offsets.x;\n    const uint r1 = expert_offsets[expert1_id] + current_intra_expert_offsets.y;\n    const uint r2 = expert_offsets[expert2_id] + current_intra_expert_offsets.z;\n    const uint r3 = expert_offsets[expert3_id] + current_intra_expert_offsets.w;\n\n    const device float4* src0 =\n        reinterpret_cast<const device float4*>(in + r0 * D + col);\n    const device float4* src1 =\n        reinterpret_cast<const device float4*>(in + r1 * D + col);\n    const device float4* src2 =\n        reinterpret_cast<const device float4*>(in + r2 * D + col);\n    const device float4* src3 =\n        reinterpret_cast<const device float4*>(in + r3 * D + col);\n\n    float4 acc = *dst4;\n    acc = metal::fma(*src0, scale0, acc);\n    acc = metal::fma(*src1, scale1, acc);\n    acc = metal::fma(*src2, scale2, acc);\n    acc = metal::fma(*src3, scale3, acc);\n    *dst4 = acc;\n}"
  },
  {
    "path": "gpt_oss/metal/source/generate.c",
    "content": "#include <assert.h>\n#include <inttypes.h>\n#include <math.h>\n#include <signal.h>\n#include <stdatomic.h>\n#include <stdbool.h>\n#include <stdio.h>\n#include <stdint.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <mach/mach_time.h>\n\n#include <gpt-oss.h>\n\n#include \"internal/model.h\"\n\nstruct {\n    atomic_uint_least64_t inference_bytes;\n    atomic_size_t num_prefill_tokens;\n    atomic_uint_least64_t prefill_microseconds;\n    atomic_size_t num_generated_tokens;\n    atomic_uint_least64_t generation_microseconds;\n} globals = {\n    .inference_bytes = 0,\n    .num_prefill_tokens = 0,\n    .prefill_microseconds = 0,\n    .num_generated_tokens = 0,\n    .generation_microseconds = 0,\n};\n\nstruct options {\n    const char* model;\n    const char* prompt;\n    size_t context_length;\n    size_t max_tokens;\n    float temperature;\n    bool verbose;\n};\n\nstatic inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {\n    static mach_timebase_info_data_t timebase_info = {0};\n    if (timebase_info.denom == 0) {\n        mach_timebase_info(&timebase_info);\n    }\n    const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;\n    return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);\n}\n\nstatic inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {\n    static mach_timebase_info_data_t timebase_info = {0};\n    if (timebase_info.denom == 0) {\n        mach_timebase_info(&timebase_info);\n    }\n    const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;\n    const uint64_t denominator = timebase_info.denom * UINT64_C(1000);\n    return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;\n}\n\nstatic void print_usage(const char* program_name) {\n    printf(\"Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\\n\", program_name);\n}\n\nstruct 
options parse_options(int argc, char** argv) {\n    struct options options = (struct options) {\n        .model = NULL,\n        .prompt = NULL,\n        .context_length = 0,\n        .max_tokens = 0,\n        .temperature = 0.0f,\n        .verbose = false,\n    };\n    if (argc < 2) {\n        fprintf(stderr, \"Error: missing required command-line argument\\n\");\n        print_usage(argv[0]);\n        exit(EXIT_FAILURE);\n    }\n    for (int i = 1; i < argc; i++) {\n        if (strcmp(argv[i], \"--help\") == 0) {\n            print_usage(argv[0]);\n            exit(EXIT_SUCCESS);\n        } else if (strcmp(argv[i], \"-p\") == 0 || strcmp(argv[i], \"--prompt\") == 0) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Error: missing argument for %s\\n\", argv[i]);\n                print_usage(argv[0]);\n                exit(EXIT_FAILURE);\n            }\n            options.prompt = argv[++i];\n        } else if (strcmp(argv[i], \"--context-length\") == 0) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Error: missing argument for --context-length\\n\");\n                print_usage(argv[0]);\n                exit(EXIT_FAILURE);\n            }\n            char* context_length_start = argv[++i];\n            char* context_length_end = context_length_start;\n            options.context_length = strtoul(context_length_start, &context_length_end, 10);\n            if (context_length_end == context_length_start || *context_length_end != 0) {\n                fprintf(stderr, \"Error: failed to parse context length value \\\"%s\\\"\\n\", context_length_start);\n                exit(EXIT_FAILURE);\n            }\n        } else if (strcmp(argv[i], \"-n\") == 0 || strcmp(argv[i], \"--max-tokens\") == 0) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Error: missing argument for %s\\n\", argv[i]);\n                print_usage(argv[0]);\n                exit(EXIT_FAILURE);\n            }\n            char* 
max_tokens_start = argv[++i];\n            char* max_tokens_end = max_tokens_start;\n            options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);\n            if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {\n                fprintf(stderr, \"Error: failed to parse max tokens value \\\"%s\\\"\\n\", max_tokens_start);\n                exit(EXIT_FAILURE);\n            }\n            if (options.max_tokens == 0) {\n                fprintf(stderr, \"Error: invalid max tokens value %zu\\n\", options.max_tokens);\n                exit(EXIT_FAILURE);\n            }\n        } else if (strcmp(argv[i], \"-t\") == 0 || strcmp(argv[i], \"--temperature\") == 0) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Error: missing argument for %s\\n\", argv[i]);\n                print_usage(argv[0]);\n                exit(EXIT_FAILURE);\n            }\n            char* temperature_start = argv[++i];\n            char* temperature_end = temperature_start;\n            options.temperature = strtof(temperature_start, &temperature_end);\n            if (temperature_end == temperature_start || *temperature_end != 0) {\n                fprintf(stderr, \"Error: failed to parse temperature value \\\"%s\\\"\\n\", temperature_start);\n                exit(EXIT_FAILURE);\n            }\n            if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {\n                fprintf(stderr, \"Error: invalid temperature value %f\\n\", options.temperature);\n                exit(EXIT_FAILURE);\n            }\n        } else if (strcmp(argv[i], \"-v\") == 0 || strcmp(argv[i], \"--verbose\") == 0) {\n            options.verbose = true;\n        } else {\n            if (options.model == NULL) {\n                options.model = argv[i];\n            } else {\n                fprintf(stderr, \"Error: unexpected command-line argument %s\\n\", argv[i]);\n                print_usage(argv[0]);\n                exit(EXIT_FAILURE);\n       
     }\n        }\n    }\n    if (options.model == NULL) {\n        fprintf(stderr, \"Error: missing required model argument\\n\");\n        print_usage(argv[0]);\n        exit(EXIT_FAILURE);\n    }\n    if (options.prompt == NULL) {\n        fprintf(stderr, \"Error: missing required prompt argument\\n\");\n        print_usage(argv[0]);\n        exit(EXIT_FAILURE);\n    }\n    return options;\n}\n\n\nstatic void print_profile() {\n    const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);\n    const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);\n    const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);\n    const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);\n    const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);\n    if (num_prefill_tokens != 0 || num_generated_tokens != 0) {\n        printf(\"\\n\");\n    }\n    if (num_prefill_tokens != 0) {\n        printf(\"Prefill speed (%zu tokens): %.1f tokens/second\\n\",\n            num_prefill_tokens,\n            (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);\n    }\n    if (num_generated_tokens != 0) {\n        printf(\"Generation speed (%zu tokens): %.1f tokens/second\\n\",\n            num_generated_tokens,\n            (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);\n    }\n}\n\nstatic void ctrl_c_handler(int signum) {\n    print_profile();\n    exit(EXIT_SUCCESS);\n}\n\nint main(int argc, char *argv[]) {\n    enum gptoss_status status;\n    gptoss_model_t model = NULL;\n    gptoss_tokenizer_t tokenizer = NULL;\n    gptoss_context_t context = NULL;\n\n    // Zero-initialize so that sa_flags and sa_mask hold well-defined values before installing the handler.\n    struct sigaction act;\n    memset(&act, 0, sizeof(act));\n    sigemptyset(&act.sa_mask);\n    act.sa_handler = ctrl_c_handler;\n    sigaction(SIGINT, &act, NULL);\n\n    setvbuf(stdout, NULL, _IONBF, 0);\n\n    struct options options = parse_options(argc, argv);\n\n    const uint64_t load_start_time = mach_continuous_time();\n    status = 
gptoss_model_create_from_file(options.model, &model);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to load model from file %s\\n\", options.model);\n        goto error;\n    }\n    size_t max_model_context_length = 0;\n    status = gptoss_model_get_max_context_length(model, &max_model_context_length);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to query maximum context length\\n\");\n        goto error;\n    }\n    assert(max_model_context_length != 0);\n    if (options.context_length == 0) {\n        options.context_length = max_model_context_length;\n    } else if (options.context_length > max_model_context_length) {\n        fprintf(stderr, \"Error: context length %zu exceeds maximum context length %zu supported by the model\\n\", options.context_length, max_model_context_length);\n        goto error;\n    }\n\n    status = gptoss_model_get_tokenizer(model, &tokenizer);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to retrieve Tokenizer\\n\");\n        goto error;\n    }\n\n    uint32_t return_token_id = UINT32_MAX;\n    status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to query return token ID\\n\");\n        goto error;\n    }\n\n    status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to create Context object\\n\");\n        goto error;\n    }\n    if (options.verbose) {\n        printf(\"Model weights size: %.2lf MB\\n\", (double) model->weights_size * 0x1.0p-20);\n        printf(\"Model allocation size: %.2lf MB\\n\", (double) model->allocation_size * 0x1.0p-20);\n        printf(\"Context allocation size: %.2lf MB\\n\", (double) context->allocation_size * 0x1.0p-20);\n  
      printf(\"  Including KV cache: %.2lf MB\\n\", (double) context->kvcache_size * 0x1.0p-20);\n    }\n\n    const uint64_t load_end_time = mach_continuous_time();\n    const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);\n    if (options.verbose) {\n        printf(\"Loaded model in %.3f seconds\\n\", load_elapsed_seconds);\n    }\n\n    const uint64_t prefill_start_time = mach_continuous_time();\n    size_t num_prefill_tokens = 0;\n    status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to tokenize prompt \\\"%s\\\"\\n\", options.prompt);\n        goto error;\n    }\n    atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);\n    status = gptoss_context_process(context);\n    if (status != gptoss_status_success) {\n        fprintf(stderr, \"Error: failed to process Context object\\n\");\n        goto error;\n    }\n    const uint64_t prefill_end_time = mach_continuous_time();\n\n    while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {\n\n        uint32_t predicted_token = UINT32_MAX;\n        size_t num_predicted_tokens = 0;\n        const uint64_t inference_start_timestamp = mach_continuous_time();\n        status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);\n        if (status != gptoss_status_success) {\n            fprintf(stderr, \"Error: failed to sample from the Context object\\n\");\n            goto error;\n        }\n        const uint64_t inference_end_timestamp = mach_continuous_time();\n\n        if (predicted_token == return_token_id) {\n            // Yield token -> stop generation\n            break;\n        }\n\n        // Unembedding: detokenize\n        size_t token_size = 0;\n        const void* token_ptr = NULL;\n    
    status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);\n        if (status != gptoss_status_success) {\n            fprintf(stderr, \"Error: failed to detokenize predicted token %\" PRIu32 \"\\n\", predicted_token);\n            goto error;\n        }\n        const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);\n        if (previous_num_generated_tokens == 0) {\n            atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));\n        } else {\n            atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));\n        }\n        printf(\"%.*s\", (int) token_size, (const char*) token_ptr);\n\n        status = gptoss_context_append_tokens(context, 1, &predicted_token);\n        if (status != gptoss_status_success) {\n            fprintf(stderr, \"Error: failed to append predicted token %\" PRIu32 \" to context\\n\", predicted_token);\n            goto error;\n        }\n    }\n\n    print_profile();\n\n    return EXIT_SUCCESS;\n\nerror:\n    gptoss_context_release(context);\n    gptoss_tokenizer_release(tokenizer);\n    gptoss_model_release(model);\n    return EXIT_FAILURE;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/datatype.h",
    "content": "#pragma once\n\n#include <stdint.h>\n\n#include <internal/macros.h>\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(2) uint16_t bits;\n} gptoss_bfloat16;\nstatic_assert(sizeof(gptoss_bfloat16) == 2, \"bfloat16 size is not 2 bytes\");\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(2) uint16_t bits;\n} gptoss_float16;\nstatic_assert(sizeof(gptoss_float16) == 2, \"float16 size is not 2 bytes\");\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(1) uint8_t bits;\n} gptoss_float8ue8m0;\nstatic_assert(sizeof(gptoss_float8ue8m0) == 1, \"gptoss_float8ue8m0 size is not 1 bytes\");\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(1) uint8_t bits;\n} gptoss_float8e5m2;\nstatic_assert(sizeof(gptoss_float8e5m2) == 1, \"float8e5m2 size is not 1 bytes\");\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(1) uint8_t bits;\n} gptoss_float8e4m3;\nstatic_assert(sizeof(gptoss_float8e4m3) == 1, \"gptoss_float8e4m3 size is not 1 bytes\");\n\n\ntypedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {\n    GPTOSS_ALIGN(1) uint8_t bits;\n} gptoss_float4e2m1x2;\nstatic_assert(sizeof(gptoss_float4e2m1x2) == 1, \"gptoss_float4e2m1x2 size is not 1 bytes\");\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/datatype.hpp",
    "content": "#pragma once\n\n#include <bit>\n\n#include <internal/datatype.h>\n\n\nnamespace gptoss {\n\ntemplate <typename WideT, typename NarrowT>\nWideT upcast(NarrowT);\n\ntemplate <>\ninline float upcast<float>(gptoss_bfloat16 bf16_value) {\n    const uint32_t bits = static_cast<uint32_t>(bf16_value.bits) << 16;\n    return std::bit_cast<float>(bits);\n}\n\ntemplate <>\ninline float upcast<float>(gptoss_float16 fp16_value) {\n    return static_cast<float>(std::bit_cast<_Float16>(fp16_value.bits));\n}\n\ntemplate <>\ninline float upcast<float>(gptoss_float8e4m3 fp8_value) {\n    static constexpr uint16_t fp8e4m3_to_fp32[256] = {\n        0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60,\n        0x3C80, 0x3C90, 0x3CA0, 0x3CB0, 0x3CC0, 0x3CD0, 0x3CE0, 0x3CF0,\n        0x3D00, 0x3D10, 0x3D20, 0x3D30, 0x3D40, 0x3D50, 0x3D60, 0x3D70,\n        0x3D80, 0x3D90, 0x3DA0, 0x3DB0, 0x3DC0, 0x3DD0, 0x3DE0, 0x3DF0,\n        0x3E00, 0x3E10, 0x3E20, 0x3E30, 0x3E40, 0x3E50, 0x3E60, 0x3E70,\n        0x3E80, 0x3E90, 0x3EA0, 0x3EB0, 0x3EC0, 0x3ED0, 0x3EE0, 0x3EF0,\n        0x3F00, 0x3F10, 0x3F20, 0x3F30, 0x3F40, 0x3F50, 0x3F60, 0x3F70,\n        0x3F80, 0x3F90, 0x3FA0, 0x3FB0, 0x3FC0, 0x3FD0, 0x3FE0, 0x3FF0,\n        0x4000, 0x4010, 0x4020, 0x4030, 0x4040, 0x4050, 0x4060, 0x4070,\n        0x4080, 0x4090, 0x40A0, 0x40B0, 0x40C0, 0x40D0, 0x40E0, 0x40F0,\n        0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170,\n        0x4180, 0x4190, 0x41A0, 0x41B0, 0x41C0, 0x41D0, 0x41E0, 0x41F0,\n        0x4200, 0x4210, 0x4220, 0x4230, 0x4240, 0x4250, 0x4260, 0x4270,\n        0x4280, 0x4290, 0x42A0, 0x42B0, 0x42C0, 0x42D0, 0x42E0, 0x42F0,\n        0x4300, 0x4310, 0x4320, 0x4330, 0x4340, 0x4350, 0x4360, 0x4370,\n        0x4380, 0x4390, 0x43A0, 0x43B0, 0x43C0, 0x43D0, 0x43E0, 0x7FF0,\n        0x8000, 0xBB00, 0xBB80, 0xBBC0, 0xBC00, 0xBC20, 0xBC40, 0xBC60,\n        0xBC80, 0xBC90, 0xBCA0, 0xBCB0, 0xBCC0, 0xBCD0, 0xBCE0, 0xBCF0,\n        0xBD00, 0xBD10, 0xBD20, 0xBD30, 
0xBD40, 0xBD50, 0xBD60, 0xBD70,\n        0xBD80, 0xBD90, 0xBDA0, 0xBDB0, 0xBDC0, 0xBDD0, 0xBDE0, 0xBDF0,\n        0xBE00, 0xBE10, 0xBE20, 0xBE30, 0xBE40, 0xBE50, 0xBE60, 0xBE70,\n        0xBE80, 0xBE90, 0xBEA0, 0xBEB0, 0xBEC0, 0xBED0, 0xBEE0, 0xBEF0,\n        0xBF00, 0xBF10, 0xBF20, 0xBF30, 0xBF40, 0xBF50, 0xBF60, 0xBF70,\n        0xBF80, 0xBF90, 0xBFA0, 0xBFB0, 0xBFC0, 0xBFD0, 0xBFE0, 0xBFF0,\n        0xC000, 0xC010, 0xC020, 0xC030, 0xC040, 0xC050, 0xC060, 0xC070,\n        0xC080, 0xC090, 0xC0A0, 0xC0B0, 0xC0C0, 0xC0D0, 0xC0E0, 0xC0F0,\n        0xC100, 0xC110, 0xC120, 0xC130, 0xC140, 0xC150, 0xC160, 0xC170,\n        0xC180, 0xC190, 0xC1A0, 0xC1B0, 0xC1C0, 0xC1D0, 0xC1E0, 0xC1F0,\n        0xC200, 0xC210, 0xC220, 0xC230, 0xC240, 0xC250, 0xC260, 0xC270,\n        0xC280, 0xC290, 0xC2A0, 0xC2B0, 0xC2C0, 0xC2D0, 0xC2E0, 0xC2F0,\n        0xC300, 0xC310, 0xC320, 0xC330, 0xC340, 0xC350, 0xC360, 0xC370,\n        0xC380, 0xC390, 0xC3A0, 0xC3B0, 0xC3C0, 0xC3D0, 0xC3E0, 0xFFF0,\n    };\n    const gptoss_bfloat16 bf16_value{.bits = fp8e4m3_to_fp32[fp8_value.bits]};\n    return upcast<float>(bf16_value);\n}\n\ntemplate <>\ninline double upcast<double>(float fp32_value) {\n    return static_cast<double>(fp32_value);\n}\n\ntemplate <>\ninline double upcast<double>(gptoss_bfloat16 bf16_value) {\n    const float fp32_value = upcast<float>(bf16_value);\n    return upcast<double>(fp32_value);\n}\n\ntemplate <>\ninline double upcast<double>(gptoss_float16 fp16_value) {\n    const float fp32_value = upcast<float>(fp16_value);\n    return upcast<double>(fp32_value);\n}\n\ntemplate <>\ninline double upcast<double>(gptoss_float8e4m3 fp8_value) {\n    const float fp32_value = upcast<float>(fp8_value);\n    return upcast<double>(fp32_value);\n}\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/kernel-args.h",
    "content": "#pragma once\n\n#if !defined(__METAL_VERSION__)\n#include <stdint.h>\n#endif\n\n// TODO(ibahmed): specalize using metal function constants.\n#define QKV_Bm 64\n#define QKV_Bn 64\n#define QKV_Bk 32\n#define QKV_Sg_Bm 32\n#define QKV_Sg_Bn 32\n\n#define ATTN_OUTPUT_Bm 32\n#define ATTN_OUTPUT_Bn 64\n#define ATTN_OUTPUT_Bk 64\n#define ATTN_OUTPUT_Sg_Bm 32\n#define ATTN_OUTPUT_Sg_Bn 16\n\n#define MLP_GATE_Bm 64\n#define MLP_GATE_Bn 16\n#define MLP_GATE_Bk 64\n#define MLP_GATE_Sg_Bm 16\n#define MLP_GATE_Sg_Bn 16\n\n#define MOE_DENSE_MATMUL_SWIGLU_Bm 32\n#define MOE_DENSE_MATMUL_SWIGLU_Bn 64\n#define MOE_DENSE_MATMUL_SWIGLU_Bk 16\n#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bm 32\n#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bn 16\n\n#define MOE_DENSE_MATMUL_Bm 32\n#define MOE_DENSE_MATMUL_Bn 64\n#define MOE_DENSE_MATMUL_Bk 16\n#define MOE_DENSE_MATMUL_Sg_Bm 32\n#define MOE_DENSE_MATMUL_Sg_Bn 16\n\nstruct gptoss_expert_prediction {\n    uint32_t expert_id;\n    float score;\n};\n\nstruct gptoss_control {\n    uint32_t abort;\n};\n\nstruct gptoss_topk_args {\n    uint32_t num_vecs_per_token;\n};\n\nstruct gptoss_sdpa_args {\n    uint32_t qkv_dim;\n    uint32_t num_kv_tokens;\n    uint32_t kv_stride;\n    uint32_t window;\n};\n\nstruct gptoss_u32_fill_random_args {\n    uint64_t num_vecs_per_threadgroup;\n    uint64_t num_vecs;\n    uint64_t offset;\n    uint64_t seed;\n};\n\nstruct gptoss_f32_fill_random_args {\n    uint64_t num_vecs_per_threadgroup;\n    uint64_t num_vecs;\n    uint64_t offset;\n    uint64_t seed;\n    float scale;\n    float bias;\n};\n\nstruct gptoss_accumulate_args {\n    uint32_t num_vecs_per_expert;\n    uint32_t num_vecs_per_threadgroup;\n    uint32_t num_vecs;\n};\n\nstruct gptoss_convert_args {\n    uint64_t num_vecs_per_threadgroup;\n    uint64_t num_vecs;\n};\n\nstruct gptoss_embeddings_args {\n    uint32_t num_vecs;\n};\n\nstruct gptoss_rmsnorm_args {\n    uint32_t num_vecs;\n    float num_channels;\n    float epsilon;\n};\n\nstruct 
gptoss_matmul_args {\n    uint32_t num_column_vecs;\n    uint32_t num_rows;\n    uint32_t add;\n};\n\nstruct gptoss_dense_matmul_args {\n    uint32_t m;\n    uint32_t n;\n    uint32_t k;\n};\n\n// Specialize qkv matmul args as it writes kv directly to the KV cache buffer.\nstruct gptoss_dense_matmul_qkv_args {\n    uint32_t m;\n    uint32_t n;\n    uint32_t k;\n    uint32_t max_tokens;\n    uint32_t token_offset;\n};\n\nstruct gptoss_scatter_args {\n    uint32_t tokens;\n    uint32_t active_experts_per_token;\n    uint32_t token_stride;\n};\n\nstruct gptoss_moe_dense_matmul_swiglu_args {\n    uint32_t k;\n    uint32_t n;\n    uint32_t weight_blocks_expert_stride_bytes;\n    uint32_t weight_scales_expert_stride_bytes;\n    uint32_t bias_expert_stride_bytes;\n    float swiglu_min;\n    float swiglu_max;\n};\n\nstruct gptoss_moe_dense_matmul_args {\n    uint32_t k;\n    uint32_t n;\n    uint32_t weight_blocks_expert_stride_bytes;\n    uint32_t weight_scales_expert_stride_bytes;\n    uint32_t bias_expert_stride_bytes;\n};\n\nstruct gptoss_expert_routing_metadata_args {\n    uint32_t tokens;\n    uint32_t num_experts;\n};\n\nstruct gptoss_gather_args {\n    uint32_t tokens;\n    uint32_t active_experts_per_token;\n    uint32_t token_stride;\n};\n\nstruct gptoss_unembedding_args {\n    uint32_t num_column_vecs;\n    uint32_t num_rows_per_threadgroup;\n    uint32_t num_rows;\n};\n\nstruct gptoss_moe_matmul_swiglu_args {\n    uint32_t num_column_vecs;\n    uint32_t num_rows;\n    uint32_t num_active_experts;\n    uint32_t weight_expert_stride;  // in bytes\n    uint32_t output_expert_stride;  // in elements\n    float swiglu_min;\n    float swiglu_max;\n};\n\nstruct gptoss_moe_matmul_args {\n    uint32_t num_column_vecs;\n    uint32_t num_rows;\n    uint32_t num_active_experts;\n    uint32_t input_expert_stride;  // in blocks of 32 elements\n    uint32_t weight_expert_stride;  // in bytes\n    uint32_t output_expert_stride;  // in elements\n};\n\nstruct gptoss_rope_args {\n    
uint32_t token_stride;\n    uint32_t token_offset;\n    uint32_t max_tokens;\n    float freq_scale;\n    float interpolation_scale;\n    float yarn_offset;\n    float yarn_scale;\n    float yarn_multiplier;\n};\n\nstruct gptoss_qkv_args {\n    uint32_t num_column_vecs;\n    uint32_t num_rows;\n    uint32_t token_offset;\n    float freq_scale;\n    float interpolation_scale;\n    float yarn_offset;\n    float yarn_scale;\n    float yarn_multiplier;\n    uint32_t max_tokens;\n};\n\nstruct gptoss_softmax_args {\n    uint32_t num_vecs;\n    uint32_t num_vecs_per_threadgroup;\n    uint32_t max_threadgroups;\n    float temperature;\n};\n\nstruct gptoss_sample_args {\n    uint64_t rng_seed;\n    uint32_t rng_offset;\n    uint32_t num_blocks;\n    uint32_t num_dims;\n    uint32_t num_dims_per_block;\n};\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/log.h",
    "content": "#pragma once\n\n#include <stdarg.h>\n\n\nvoid gptoss_format_log(const char* format, va_list args);\n\n__attribute__((__format__(__printf__, 1, 2)))\ninline static void gptoss_log(const char* format, ...) {\n    va_list args;\n    va_start(args, format);\n    gptoss_format_log(format, args);\n    va_end(args);\n}\n\n#define GPTOSS_LOG_ERROR(message, ...) \\\n    gptoss_log(\"Error: \" message \"\\n\", ##__VA_ARGS__)\n\n#define GPTOSS_LOG_WARNING(message, ...) \\\n    gptoss_log(\"Warning: \" message \"\\n\", ##__VA_ARGS__)\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/macros.h",
    "content": "#pragma once\n\n/***** Architecture detection macros *****/\n\n#ifdef GPTOSS_ARCH_X86_64\n    #if GPTOSS_ARCH_X86_64 != 0 && GPTOSS_ARCH_X86_64 != 1\n        #error \"Invalid GPTOSS_ARCH_X86_64 value: must be either 0 or 1\"\n    #endif\n#else\n    #if defined(__x86_64__) || defined(_M_X64) && !defined(_M_ARM64EC)\n        #define GPTOSS_ARCH_X86_64 1\n    #else\n        #define GPTOSS_ARCH_X86_64 0\n    #endif\n#endif\n\n#ifdef GPTOSS_ARCH_ARM64\n    #if GPTOSS_ARCH_ARM64 != 0 && GPTOSS_ARCH_ARM64 != 1\n        #error \"Invalid GPTOSS_ARCH_ARM64 value: must be either 0 or 1\"\n    #endif\n#else\n    #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)\n        #define GPTOSS_ARCH_ARM64 1\n    #else\n        #define GPTOSS_ARCH_ARM64 0\n    #endif\n#endif\n\n#if GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 == 0\n    #error \"Unsupported architecture: neither x86-64 nor ARM64 detected\"\n#elif GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 != 1\n    #error \"Inconsistent architecture detection: both x86-64 and ARM64 detection macros are specified\"\n#endif\n\n/***** Compiler portability macros *****/\n\n#ifndef GPTOSS_LIKELY\n    #if defined(__GNUC__)\n        #define GPTOSS_LIKELY(condition) (__builtin_expect(!!(condition), 1))\n    #else\n        #define GPTOSS_LIKELY(condition) (!!(condition))\n    #endif\n#endif\n\n#ifndef GPTOSS_UNLIKELY\n    #if defined(__GNUC__)\n        #define GPTOSS_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))\n    #else\n        #define GPTOSS_UNLIKELY(condition) (!!(condition))\n    #endif\n#endif\n\n#ifndef GPTOSS_UNPREDICTABLE\n    #if defined(__has_builtin)\n        #if __has_builtin(__builtin_unpredictable)\n            #define GPTOSS_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))\n        #endif\n    #endif\n#endif\n#ifndef GPTOSS_UNPREDICTABLE\n    #if defined(__GNUC__) && (__GNUC__ >= 9) && !defined(__INTEL_COMPILER)\n        #define GPTOSS_UNPREDICTABLE(condition) 
(__builtin_expect_with_probability(!!(condition), 0, 0.5))\n    #else\n        #define GPTOSS_UNPREDICTABLE(condition) (!!(condition))\n    #endif\n#endif\n\n// Disable padding for structure members.\n#ifndef GPTOSS_DENSELY_PACKED_STRUCTURE\n    #if defined(__GNUC__)\n        #define GPTOSS_DENSELY_PACKED_STRUCTURE __attribute__((__packed__))\n    #else\n        #error \"Compiler-specific implementation of GPTOSS_DENSELY_PACKED_STRUCTURE required\"\n    #endif\n#endif\n\n#ifndef GPTOSS_ALIGN\n    #if defined(__GNUC__)\n        #define GPTOSS_ALIGN(alignment) __attribute__((__aligned__(alignment)))\n    #elif defined(_MSC_VER)\n        #define GPTOSS_ALIGN(alignment) __declspec(align(alignment))\n    #else\n        #error \"Compiler-specific implementation of GPTOSS_ALIGN required\"\n    #endif\n#endif\n\n#ifndef GPTOSS_FORCE_INLINE\n    #if defined(__GNUC__)\n        #define GPTOSS_FORCE_INLINE inline __attribute__((__always_inline__))\n    #elif defined(_MSC_VER)\n        #define GPTOSS_FORCE_INLINE __forceinline\n    #else\n        #define GPTOSS_FORCE_INLINE inline\n    #endif\n#endif\n\n/***** Symbol visibility macros *****/\n\n#ifndef GPTOSS_INTERNAL_SYMBOL\n    #if defined(__ELF__)\n        #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__(\"internal\")))\n    #elif defined(__MACH__)\n        #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__(\"hidden\")))\n    #else\n        #define GPTOSS_INTERNAL_SYMBOL\n    #endif\n#endif\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/math.h",
    "content": "#pragma once\n\n#include <assert.h>\n#include <stddef.h>\n#include <stdint.h>\n\ninline static size_t math_ceil_div(size_t numer, size_t denom) {\n    return (numer + denom - 1) / denom;\n}\n\ninline static size_t math_max(size_t a, size_t b) {\n    return a >= b ? a : b;\n}\n\ninline static size_t math_min(size_t a, size_t b) {\n    return a < b ? a : b;\n}\n\ninline static size_t math_sub_sat(size_t a, size_t b) {\n    return a > b ? a - b : 0;\n}\n\nstatic size_t math_round_down_po2(size_t number, size_t multiple) {\n    assert(multiple != 0);\n    assert((multiple & (multiple - 1)) == 0);\n\n    return number & -multiple;\n}\n\nstatic size_t math_round_up_po2(size_t number, size_t multiple) {\n    assert(multiple != 0);\n    assert((multiple & (multiple - 1)) == 0);\n\n    const size_t multiple_mask = multiple - 1;\n    if ((number & multiple_mask) != 0) {\n        number |= multiple_mask;\n        number += 1;\n    }\n    return number;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/metal-kernels.h",
    "content": "#pragma once\n\n#include <stddef.h>\n#include <stdint.h>\n\n#include <internal/metal.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#include <stddef.h>\n#include <stdint.h>\n\n#include <internal/kernel-args.h>\n#include <internal/math.h>\n#include <internal/metal.h>\n\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* u32_fill_random_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_fill_random_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset,\n    float rng_min,\n    float rng_max);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* bf16_fill_random_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset,\n    float rng_min,\n    float rng_max);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* mf4_f32_convert_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* block_buffer,\n    const struct 
gptoss_metal_buffer* scale_buffer,\n    const struct gptoss_metal_buffer* output_buffer,\n    uint64_t num_elements);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* bf16_f32_embeddings_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* token_buffer,\n    size_t token_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_channels);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_channels,\n    float epsilon);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t 
num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t attn_head_dim,\n    uint32_t token_offset,\n    uint32_t max_tokens,\n    float rope_base,\n    float interpolation_scale,\n    float yarn_offset,\n    float yarn_scale,\n    float yarn_multiplier);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status\ngptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct 
gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows,\n    uint32_t max_tokens,\n    uint32_t token_offset);\n\nenum gptoss_status\ngptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status\ngptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    
uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* argmax_buffer,\n    size_t argmax_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    float swiglu_limit,\n    uint32_t expert_stride,\n    uint32_t num_tokens,\n    uint32_t num_active_experts,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct 
gptoss_metal_function* f32_mf4w_moe_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t expert_stride,\n    uint32_t num_tokens,\n    uint32_t num_active_experts,\n    uint32_t num_cols,\n    uint32_t num_rows);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_rope_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* activations_buffer,\n    size_t activations_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    float rope_base,\n    float interpolation_scale,\n    float yarn_offset,\n    float yarn_scale,\n    float yarn_multiplier,\n    uint32_t num_tokens,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t attn_head_dim,\n    uint32_t max_tokens,\n    uint32_t token_offset);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_accumulate_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const 
struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_experts);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* expert_routing_metadata_fn,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    size_t intra_expert_offsets_offset,\n    uint32_t num_tokens,\n    uint32_t num_experts);\n\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_scatter_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    size_t intra_expert_offsets_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_active_experts);\n    \nenum gptoss_status\ngptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t 
input_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    float swiglu_limit,\n    uint32_t expert_stride_bytes,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_cols,\n    uint32_t num_rows);\n    \nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t expert_stride_bytes,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_cols,\n    uint32_t num_rows);\n    \nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    
size_t intra_expert_offsets_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_active_experts);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_topk_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_active_experts);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_sdpa_fn,\n    const struct gptoss_metal_buffer* q_buffer,\n    size_t q_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* s_buffer,\n    size_t s_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t window,\n    uint32_t kv_stride,\n    uint32_t num_q_tokens,\n    uint32_t num_kv_tokens,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t head_dim);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_softmax_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* score_buffer,\n    size_t score_offset,\n    const struct gptoss_metal_buffer* argmax_buffer,\n    size_t argmax_offset,\n    const struct gptoss_metal_buffer* prob_buffer,\n    size_t prob_offset,\n    const struct 
gptoss_metal_buffer* sum_buffer,\n    size_t sum_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    float temperature,\n    uint32_t* num_threadgroups_out,\n    uint32_t* num_channels_per_threadgroup_out);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_sample_fn,\n    size_t min_threadgroup_size,\n    const struct gptoss_metal_buffer* prob_buffer,\n    size_t prob_offset,\n    const struct gptoss_metal_buffer* sum_buffer,\n    size_t sum_offset,\n    const struct gptoss_metal_buffer* token_buffer,\n    size_t token_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint64_t rng_seed,\n    uint32_t rng_offset,\n    uint32_t num_blocks,\n    uint32_t num_channels,\n    uint32_t num_channels_per_block);\n\n#ifdef __cplusplus\n}  // extern \"C\"\n#endif\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/metal.h",
    "content": "#pragma once\n\n#include <stddef.h>\n\n#include <gpt-oss/types.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct gptoss_metal_device {\n    void* object; // id<MTLDevice>\n    size_t num_cores;\n    size_t max_buffer_size;\n    size_t max_threadgroup_memory;\n    size_t max_threadgroup_threads_x;\n    size_t max_threadgroup_threads_y;\n    size_t max_threadgroup_threads_z;\n};\n\nenum gptoss_status gptoss_metal_device_create_system_default(\n    struct gptoss_metal_device* device_out);\n\nenum gptoss_status gptoss_metal_device_release(\n    struct gptoss_metal_device* device);\n\n\nstruct gptoss_metal_library {\n    void* object; // id<MTLLibrary>\n};\n\nenum gptoss_status gptoss_metal_library_create_default(\n    const struct gptoss_metal_device* device,\n    struct gptoss_metal_library* library_out);\n\nenum gptoss_status gptoss_metal_library_release(\n    struct gptoss_metal_library* library);\n\nstruct gptoss_metal_function {\n    void* function_object; // id<MTLFunction>\n    void* pipeline_state_object; // id<MTLComputePipelineState>\n    size_t max_threadgroup_threads;\n    size_t simdgroup_threads;\n    size_t static_threadgroup_memory;\n};\n\nenum gptoss_status gptoss_metal_function_create(\n    const struct gptoss_metal_library* library,\n    const char* name,\n    struct gptoss_metal_function* function_out);\n\nenum gptoss_status gptoss_metal_function_release(\n    struct gptoss_metal_function* function);\n\nstruct gptoss_metal_buffer {\n    void* object; // id<MTLBuffer>\n    size_t size;\n    void* ptr;\n};\n\nenum gptoss_status gptoss_metal_buffer_create(\n    const struct gptoss_metal_device* device,\n    size_t size,\n    const void* data,\n    struct gptoss_metal_buffer* buffer_out);\n\nenum gptoss_status gptoss_metal_buffer_wrap(\n    const struct gptoss_metal_device* device,\n    size_t size,\n    const void* data,\n    struct gptoss_metal_buffer* buffer_out);\n\nenum gptoss_status gptoss_metal_buffer_release(\n    struct 
gptoss_metal_buffer* buffer);\n\nstruct gptoss_metal_command_queue {\n    void* object; // id<MTLCommandQueue>\n};\n\nenum gptoss_status gptoss_metal_command_queue_create(\n    const struct gptoss_metal_device* device,\n    struct gptoss_metal_command_queue* command_queue_out);\n\nenum gptoss_status gptoss_metal_command_queue_release(\n    struct gptoss_metal_command_queue* command_queue);\n\nstruct gptoss_metal_command_buffer {\n    void* object; // id<MTLCommandBuffer>\n};\n\nenum gptoss_status gptoss_metal_command_buffer_create(\n    const struct gptoss_metal_command_queue* command_queue,\n    struct gptoss_metal_command_buffer* command_buffer_out);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_buffer* buffer,\n    size_t offset,\n    size_t size,\n    uint8_t fill_value);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    size_t size);\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* function,\n    size_t threadgroup_size_x,\n    size_t threadgroup_size_y,\n    size_t threadgroup_size_z,\n    size_t num_threadgroups_x,\n    size_t num_threadgroups_y,\n    size_t num_threadgroups_z,\n    size_t params_size,\n    const void* params,\n    size_t num_device_buffers,\n    const struct gptoss_metal_buffer** device_buffers,\n    const size_t* device_buffer_offsets,\n    size_t threadgroup_buffer_size);\n\nenum gptoss_status gptoss_metal_command_buffer_commit(\n    const struct gptoss_metal_command_buffer* command_buffer);\n\nenum gptoss_status 
gptoss_metal_command_buffer_wait_completion(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    double* elapsed_seconds);\n\nenum gptoss_status gptoss_metal_command_buffer_release(\n    struct gptoss_metal_command_buffer* command_buffer);\n\n#ifdef __cplusplus\n}  // extern \"C\"\n#endif\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/metal.hpp",
    "content": "#pragma once\n\n#include <algorithm>\n#include <array>\n#include <initializer_list>\n#include <cstring>\n#include <stdexcept>\n#include <vector>\n\n#include <gpt-oss/types.h>\n#include <internal/metal.h>\n#include <internal/metal-kernels.h>\n\n\nnamespace gptoss {\n\ninline void Check(gptoss_status s, const char* what) {\n    if (s != gptoss_status_success) {\n        throw std::runtime_error(what);\n    }\n}\n\ninline std::size_t round_up(std::size_t p, std::size_t q) {\n    const std::size_t r = p % q;\n    if (r == 0) {\n        return p;\n    } else {\n        return p - r + q;\n    }\n}\n\nnamespace metal {\n\nclass Device {\npublic:\n    inline Device() {\n        Check(gptoss_metal_device_create_system_default(&device_), \"create Device\");\n    }\n\n    inline ~Device() {\n        gptoss_metal_device_release(&device_);\n    }\n\n    Device(const Device&) = delete;\n    Device& operator=(const Device&) = delete;\n\n    inline Device(Device&& other) noexcept {\n        device_ = other.device_;\n        std::memset(&other.device_, 0, sizeof(other.device_));\n    }\n\n    inline Device& operator=(Device&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_device_release(&device_);\n            device_ = other.device_;\n            std::memset(&other.device_, 0, sizeof(other.device_));\n        }\n        return *this;\n    }\n\n    inline const gptoss_metal_device* handle() const noexcept { return &device_; }\n\n    inline size_t max_buffer_size() const noexcept { return device_.max_buffer_size; }\n    inline size_t max_threadgroup_memory() const noexcept { return device_.max_threadgroup_memory; }\n    inline size_t max_threadgroup_threads_x() const noexcept { return device_.max_threadgroup_threads_x; }\n    inline size_t max_threadgroup_threads_y() const noexcept { return device_.max_threadgroup_threads_y; }\n    inline size_t max_threadgroup_threads_z() const noexcept { return device_.max_threadgroup_threads_z; }\n\nprivate:\n    
gptoss_metal_device device_{};\n};\n\nclass Library {\npublic:\n    inline explicit Library(const Device& dev) {\n        Check(gptoss_metal_library_create_default(dev.handle(), &library_),\n            \"gptoss_metal_library_create_default\");\n    }\n\n    inline ~Library() {\n        gptoss_metal_library_release(&library_);\n    }\n\n    Library(const Library&) = delete;\n    Library& operator=(const Library&) = delete;\n\n    inline Library(Library&& other) noexcept {\n        library_ = other.library_;\n        std::memset(&other.library_, 0, sizeof(other.library_));\n    }\n\n    inline Library& operator=(Library&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_library_release(&library_);\n            library_ = other.library_;\n            std::memset(&other.library_, 0, sizeof(other.library_));\n        }\n        return *this;\n    }\n\n    inline const gptoss_metal_library* handle() const noexcept {\n        return &library_;\n    }\n\nprivate:\n    gptoss_metal_library library_{};\n};\n\nclass Function {\npublic:\n    inline Function(const Library& library, const char* name) {\n        Check(gptoss_metal_function_create(library.handle(), name, &function_),\n            \"gptoss_metal_function_create\");\n    }\n\n    inline ~Function() {\n        gptoss_metal_function_release(&function_);\n    }\n\n    Function(const Function&) = delete;\n    Function& operator=(const Function&) = delete;\n\n    inline Function(Function&& other) noexcept {\n        function_ = other.function_;\n        std::memset(&other.function_, 0, sizeof(other.function_));\n    }\n\n    inline Function& operator=(Function&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_function_release(&function_);\n            function_ = other.function_;\n            std::memset(&other.function_, 0, sizeof(other.function_));\n        }\n        return *this;\n    }\n\n    inline const gptoss_metal_function* handle() const noexcept { return 
&function_; }\n\n    inline size_t max_threadgroup_threads() const noexcept { return function_.max_threadgroup_threads; }\n    inline size_t simdgroup_threads() const noexcept { return function_.simdgroup_threads; }\n    inline size_t static_threadgroup_memory() const noexcept { return function_.static_threadgroup_memory; }\n\nprivate:\n    gptoss_metal_function function_{};\n};\n\nclass Buffer {\npublic:\n    inline Buffer(const Device& dev, size_t size, const void* data = nullptr) {\n        Check(gptoss_metal_buffer_create(dev.handle(), size, data, &buffer_), \"create buffer\");\n    }\n\n    inline ~Buffer() {\n        gptoss_metal_buffer_release(&buffer_);\n    }\n\n    Buffer(const Buffer&) = delete;\n    Buffer& operator=(const Buffer&) = delete;\n\n    inline Buffer(Buffer&& other) noexcept {\n        buffer_ = other.buffer_;\n        std::memset(&other.buffer_, 0, sizeof(other.buffer_));\n    }\n\n    inline Buffer& operator=(Buffer&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_buffer_release(&buffer_);\n            buffer_ = other.buffer_;\n            std::memset(&other.buffer_, 0, sizeof(other.buffer_));\n        }\n        return *this;\n    }\n\n    inline size_t size() const noexcept { return buffer_.size; }\n    inline void* ptr() const noexcept { return buffer_.ptr; }\n\n    inline const gptoss_metal_buffer* handle() const noexcept { return &buffer_; }\n\nprivate:\n    gptoss_metal_buffer buffer_{};\n};\n\nclass CommandQueue {\npublic:\n    inline explicit CommandQueue(const Device& dev) {\n        Check(gptoss_metal_command_queue_create(dev.handle(), &command_queue_),\n            \"gptoss_metal_command_queue_create\");\n    }\n\n    inline ~CommandQueue() {\n        gptoss_metal_command_queue_release(&command_queue_);\n    }\n\n    CommandQueue(const CommandQueue&) = delete;\n    CommandQueue& operator=(const CommandQueue&) = delete;\n\n    inline CommandQueue(CommandQueue&& other) noexcept {\n        command_queue_ 
= other.command_queue_;\n        std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));\n    }\n\n    inline CommandQueue& operator=(CommandQueue&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_command_queue_release(&command_queue_);\n            command_queue_ = other.command_queue_;\n            std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));\n        }\n        return *this;\n    }\n\n    inline const gptoss_metal_command_queue* handle() const noexcept {\n        return &command_queue_;\n    }\n\nprivate:\n    gptoss_metal_command_queue command_queue_{};\n};\n\nclass CommandBuffer {\npublic:\n    inline explicit CommandBuffer(const CommandQueue& command_queue) {\n        Check(gptoss_metal_command_buffer_create(command_queue.handle(), &command_buffer_),\n            \"gptoss_metal_command_buffer_create\");\n    }\n    inline ~CommandBuffer() {\n        gptoss_metal_command_buffer_release(&command_buffer_);\n    }\n\n    CommandBuffer(const CommandBuffer&)            = delete;\n    CommandBuffer& operator=(const CommandBuffer&) = delete;\n\n    inline CommandBuffer(CommandBuffer&& other) noexcept  {\n        command_buffer_ = other.command_buffer_;\n        std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));\n    }\n\n    inline CommandBuffer& operator=(CommandBuffer&& other) noexcept {\n        if (this != &other) {\n            gptoss_metal_command_buffer_release(&command_buffer_);\n            command_buffer_ = other.command_buffer_;\n            std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));\n        }\n        return *this;\n    }\n\n    inline void encode_launch_kernel(const Function& function,\n                                     const std::array<size_t, 3>& threadgroup_size,\n                                     const std::array<size_t, 3>& num_threadgroups,\n                                     size_t params_size, const void* params,\n              
                       std::initializer_list<const Buffer*> device_buffers = {},\n                                     size_t threadgroup_buffer_size = 0)\n    {\n        std::vector<const gptoss_metal_buffer*> buffer_handles(device_buffers.size());\n        std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),\n            [](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });\n        Check(gptoss_metal_command_buffer_encode_launch_kernel(\n                &command_buffer_, function.handle(),\n                threadgroup_size[0], threadgroup_size[1], threadgroup_size[2],\n                num_threadgroups[0], num_threadgroups[1], num_threadgroups[2],\n                params_size, params,\n                buffer_handles.size(),\n                buffer_handles.data(),\n                /*buffer_offsets=*/nullptr,\n                threadgroup_buffer_size),\n            \"gptoss_metal_command_buffer_encode_launch_kernel\");\n    }\n\n    inline void encode_launch_f32_fill_random(const Function& f32_fill_random_fn,\n                                              size_t threadgroup_size,\n                                              size_t num_threadgroups,\n                                              const Buffer& output_buffer,\n                                              size_t output_offset,\n                                              size_t num_channels,\n                                              uint64_t rng_seed,\n                                              uint64_t rng_offset,\n                                              float rng_min,\n                                              float rng_max)\n    {\n        Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n                &command_buffer_, f32_fill_random_fn.handle(),\n                threadgroup_size, num_threadgroups,\n                output_buffer.handle(), output_offset,\n                num_channels,\n           
     rng_seed, rng_offset, rng_min, rng_max),\n            \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n    }\n\n    inline void encode_launch_bf16_fill_random(const Function& bf16_fill_random_fn,\n                                               size_t threadgroup_size,\n                                               size_t num_threadgroups,\n                                               const Buffer& output_buffer,\n                                               size_t output_offset,\n                                               size_t num_channels,\n                                               uint64_t rng_seed,\n                                               uint64_t rng_offset,\n                                               float rng_min,\n                                               float rng_max)\n    {\n        Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(\n                &command_buffer_, bf16_fill_random_fn.handle(),\n                threadgroup_size, num_threadgroups,\n                output_buffer.handle(), output_offset,\n                num_channels,\n                rng_seed, rng_offset, rng_min, rng_max),\n            \"gptoss_metal_command_buffer_encode_launch_bf16_fill_random\");\n    }\n\n    inline void encode_launch_u32_fill_random(const Function& u32_fill_random_fn,\n                                              size_t threadgroup_size,\n                                              size_t num_threadgroups,\n                                              const Buffer& output_buffer,\n                                              size_t output_offset,\n                                              size_t num_channels,\n                                              uint64_t rng_seed,\n                                              uint64_t rng_offset)\n    {\n        Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(\n                &command_buffer_, u32_fill_random_fn.handle(),\n       
         threadgroup_size, num_threadgroups,\n                output_buffer.handle(), output_offset,\n                num_channels,\n                rng_seed, rng_offset),\n            \"gptoss_metal_command_buffer_encode_launch_u32_fill_random\");\n    }\n\n    inline void commit() {\n        Check(gptoss_metal_command_buffer_commit(&command_buffer_), \"commit\");\n    }\n\n    inline double wait_completion() {\n        double secs = 0.0;\n        Check(gptoss_metal_command_buffer_wait_completion(&command_buffer_, &secs), \"wait completion\");\n        return secs;\n    }\n\n    inline const gptoss_metal_command_buffer* handle() const noexcept { return &command_buffer_; }\n\nprivate:\n    gptoss_metal_command_buffer command_buffer_{};\n};\n\n} // namespace metal\n} // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/model.h",
    "content": "#pragma once\n\n#ifndef __cplusplus\n    #include <stdatomic.h>\n#endif\n#include <stdbool.h>\n#include <stddef.h>\n#include <stdint.h>\n\n#include \"internal/metal.h\"\n\n\nstruct gptoss_tokenizer {\n#ifndef __cplusplus\n    atomic_uint_least64_t ref_count;\n#else\n    uint_least64_t ref_count;\n#endif\n\n    void* mapping_ptr;\n    size_t mapping_size;\n\n    const char* regex_ptr;\n    const char* tokens_ptr;\n\n    uint32_t num_text_tokens;\n    uint32_t num_special_tokens;\n\n    uint32_t special_token_id[gptoss_special_token_max - 1];\n};\n\nstruct gptoss_model {\n#ifndef __cplusplus\n    atomic_uint_least64_t ref_count;\n#else\n    uint_least64_t ref_count;\n#endif\n\n    struct gptoss_tokenizer* tokenizer;\n\n    void* mapping_ptr;\n    size_t mapping_size;\n\n    uint32_t context_length;\n    uint32_t num_blocks;\n    uint32_t num_experts;\n    uint32_t num_active_experts;\n    uint32_t embedding_dim;\n    uint32_t mlp_dim;\n    float swiglu_limit;\n    uint32_t head_dim;\n    uint32_t num_heads;\n    uint32_t num_kv_heads;\n    uint32_t attention_window;\n    float rope_theta;\n    float interpolation_scale;\n    float yarn_offset;\n    float yarn_scale;\n    float yarn_multiplier;\n    float rmsnorm_epsilon;\n\n    uint32_t vocabulary_size;\n\n    bool lock_memory;\n\n    size_t weights_size;\n    size_t allocation_size;\n\n    // Metal objects\n    struct gptoss_metal_device device;\n    size_t max_threadgroups;\n    struct gptoss_metal_command_queue command_queue;\n    struct gptoss_metal_library library;\n    struct gptoss_metal_function bf16_f32_embeddings_fn;\n    struct gptoss_metal_function f32_bf16w_rmsnorm_fn;\n    struct gptoss_metal_function f32_bf16w_matmul_fn;\n    struct gptoss_metal_function f32_bf16w_matmul_qkv_fn;\n    struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn;\n    struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn;\n    struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn;\n   
 struct gptoss_metal_function f32_bf16w_unembedding_fn;\n    struct gptoss_metal_function f32_rope_fn;\n    struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;\n    struct gptoss_metal_function f32_mf4w_moe_matmul_fn;\n    struct gptoss_metal_function f32_accumulate_e4_fn;\n    struct gptoss_metal_function f32_scatter_e4_fn;\n    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_swiglu_fn;\n    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_fn;\n    struct gptoss_metal_function f32_gather_and_accumulate_e4_fn;\n    struct gptoss_metal_function f32_expert_routing_metadata_fn;\n    struct gptoss_metal_function f32_topk_softmax_e32_k4_fn;\n    struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;\n    struct gptoss_metal_function f32_sdpa_q8_d64_fn;\n    struct gptoss_metal_function f32_softmax_fn;\n    struct gptoss_metal_function f32_sample_fn;\n\n    size_t per_block_shared_weights_size;\n    size_t per_expert_block_weight_size;\n\n    size_t embeddings_threadgroup_size;\n    size_t attn_qkv_threadgroup_size;\n    size_t attn_out_threadgroup_size;\n    size_t mlp_gate_threadgroup_size;\n    size_t mlp_swiglu_threadgroup_size;\n    size_t mlp_out_threadgroup_size;\n    size_t mlp_acc_threadgroup_size;\n    size_t unembedding_threadgroup_size;\n\n    size_t attn_rmsnorm_gain_offset;\n    size_t attn_qkv_weight_offset;\n    size_t attn_qkv_bias_offset;\n    size_t attn_sdpa_sink_offset;\n    size_t attn_out_weight_offset;\n    size_t attn_out_bias_offset;\n    size_t mlp_rmsnorm_gain_offset;\n    size_t mlp_gate_weight_offset;\n    size_t mlp_gate_bias_offset;\n    size_t mlp_swiglu_scale_offset;\n    size_t mlp_swiglu_bias_offset;\n    size_t mlp_out_block_offset;\n    size_t mlp_out_scale_offset;\n    size_t mlp_out_bias_offset;\n    size_t rmsnorm_weight_offset;\n    size_t unembedding_weight_offset;\n\n    // Buffer with non-MoE weights. 
Includes MoE gates, embeddings/unembeddings.\n    struct gptoss_metal_buffer shared_weight_buffer;\n    // num_blocks per-block buffers with MoE weights to follow.\n    struct gptoss_metal_buffer block_weight_buffers[];\n};\n\n#define GPTOSS_DEFAULT_BATCH_SIZE 128\n\nstruct gptoss_context {\n#ifndef __cplusplus\n    atomic_uint_least64_t ref_count;\n#else\n    uint_least64_t ref_count;\n#endif\n\n    struct gptoss_model* model;\n    // Number of tokens processed in the context.\n    size_t num_tokens;\n    // Number of tokens in the KV cache.\n    size_t num_kv_tokens;\n    // Length of the context.\n    size_t max_tokens;\n    // Maximum number of tokens that can be processed in a single batch.\n    // Activation buffers are allocated with this size.\n    size_t max_batch_tokens;\n\n\n    size_t kvcache_size;\n    size_t allocation_size;\n\n    // Activation buffers.\n    // TODO: merge into a single buffer.\n    struct gptoss_metal_buffer residual_activation_buffer;  // Residual stream\n    struct gptoss_metal_buffer rmsnorm_activation_buffer;  // Both attention & MLP RMSNorm output\n    struct gptoss_metal_buffer qkv_activation_buffer;  // QKV projection output\n    struct gptoss_metal_buffer sdpa_activation_buffer;  // SDPA output\n    struct gptoss_metal_buffer gate_activation_buffer;  // MoE gating output\n    struct gptoss_metal_buffer expert_activation_buffer;  // MoE expert predictions\n    struct gptoss_metal_buffer expert_offset_buffer; // MoE expert histograms cumsum\n    struct gptoss_metal_buffer token_to_expert_routing_buffer; // MoE token to expert routing\n    struct gptoss_metal_buffer swiglu_input_buffer; // MLP+SwiGLU input for prefill.\n    struct gptoss_metal_buffer swiglu_activation_buffer;  // MLP+SwiGLU output\n    struct gptoss_metal_buffer moe_activation_buffer;  // MoE MLP output (per-active expert)\n\n    // Input/output buffers.\n    struct gptoss_metal_buffer control_buffer;\n    struct gptoss_metal_buffer token_buffer;  // uint32 
token IDs\n    struct gptoss_metal_buffer score_buffer;  // unembedding outputs\n    struct gptoss_metal_buffer prob_buffer;\n    struct gptoss_metal_buffer sum_buffer;\n    struct gptoss_metal_buffer argmax_buffer;\n    struct gptoss_metal_buffer kvcache_buffer;\n};\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/rng.h",
    "content": "#pragma once\n\n#include <stdint.h>\n\ninline static uint32_t rng_squares32(uint64_t offset, uint64_t seed) {\n    const uint64_t y = offset * seed;\n    const uint64_t z = y + seed;\n\n    /* Round 1 */\n    uint64_t x = y * y + y;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 2 */\n    x = x * x + z;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 3 */\n    x = x * x + y;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 4 */\n    x = x * x + z;\n    return (uint32_t) (x >> 32);\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/rng.hpp",
    "content": "#pragma once\n\n#include <cstdint>\n\nnamespace gptoss {\n\nnamespace rng {\n\ninline static std::uint32_t squares32(std::uint64_t offset, std::uint64_t seed) {\n    const std::uint64_t y = offset * seed;\n    const std::uint64_t z = y + seed;\n\n    /* Round 1 */\n    std::uint64_t x = y * y + y;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 2 */\n    x = x * x + z;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 3 */\n    x = x * x + y;\n    x = (x >> 32) | (x << 32);\n\n    /* Round 4 */\n    x = x * x + z;\n    return static_cast<std::uint32_t>(x >> 32);\n}\n\n}  // namespace rng\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/storage.h",
    "content": "#pragma once\n\n#include <stdbool.h>\n#include <stdint.h>\n\nstruct gptoss_file_header {\n    char magic[12];\n    uint32_t zero;\n};\n\nstruct gptoss_gptoss_model_header {\n    uint32_t context_length;\n    uint32_t num_blocks;\n    uint32_t num_experts;\n    uint32_t num_active_experts;\n    uint32_t embedding_dim;\n    uint32_t mlp_dim;\n    float swiglu_limit;\n    uint32_t head_dim;\n    uint32_t num_heads;\n    uint32_t num_kv_heads;\n    uint32_t attention_window;\n    float rope_theta;\n    float interpolation_scale;\n    float yarn_offset;\n    float yarn_scale;\n    float yarn_multiplier;\n    float rmsnorm_epsilon;\n};\n\nstruct gptoss_tiktoken_tokenizer_header {\n    uint32_t num_special_tokens;\n    uint32_t num_text_tokens;\n    uint32_t regex_size;\n    uint32_t tokens_size;\n};\n"
  },
  {
    "path": "gpt_oss/metal/source/include/internal/uuid.h",
    "content": "#pragma once\n\n#include <stdbool.h>\n#include <stdint.h>\n#include <string.h>\n\n#include \"internal/macros.h\"\n\n\nstruct GPTOSS_DENSELY_PACKED_STRUCTURE gptoss_uuid {\n    uint8_t bytes[16];\n};\nstatic_assert(sizeof(struct gptoss_uuid) == 16, \"UUID size is not 16 bytes\");\n\n\n#define UUID_FORMAT \"%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X\"\n#define UUID_ARGS(uuid) (uuid).bytes[0], (uuid).bytes[1], (uuid).bytes[2], (uuid).bytes[3], \\\n    (uuid).bytes[4], (uuid).bytes[5], (uuid).bytes[6], (uuid).bytes[7], (uuid).bytes[8], (uuid).bytes[9], \\\n    (uuid).bytes[10], (uuid).bytes[11], (uuid).bytes[12], (uuid).bytes[13], (uuid).bytes[14], (uuid).bytes[15]\n\nstatic inline bool gptoss_is_gptoss_model_uuid(const struct gptoss_uuid* uuid) {\n    return memcmp(\n        &(struct gptoss_uuid) {0xDF, 0x52, 0xDC, 0x86, 0x17, 0x89, 0x4E, 0xD0, 0xA2, 0x95, 0x66, 0xF1, 0x05, 0x08, 0x14, 0x5B},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0;\n}\n\nstatic inline bool gptoss_is_applegpu_layout_uuid(const struct gptoss_uuid* uuid) {\n    return memcmp(\n        &(struct gptoss_uuid) {0x22, 0x91, 0x77, 0xA8, 0x57, 0x75, 0x42, 0x68, 0xBF, 0xD8, 0xD5, 0x88, 0xB3, 0x51, 0xC5, 0x6D},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0;\n}\n\nstatic inline bool gptoss_is_tiktoken_tokenizer_uuid(const struct gptoss_uuid* uuid) {\n    return memcmp(\n        &(struct gptoss_uuid) {0x74, 0x01, 0xAD, 0xED, 0x2A, 0x95, 0x40, 0xCB, 0xB7, 0x82, 0x9C, 0xCE, 0xBA, 0xAF, 0xE7, 0x2B},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0;\n}\n\nstatic inline enum gptoss_special_token gptoss_special_token_decode_uuid(const struct gptoss_uuid* uuid) {\n    if (memcmp(\n        &(struct gptoss_uuid) {0x55, 0xA7, 0x7C, 0x2F, 0x8A, 0x01, 0x4C, 0x54, 0x8A, 0xC2, 0x31, 0x3B, 0xFC, 0x7E, 0x20, 0x8D},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_start;\n    } else if (memcmp(\n       
 &(struct gptoss_uuid) {0x16, 0xE4, 0x04, 0x31, 0xF4, 0x7F, 0x4B, 0x22, 0xB5, 0x9B, 0x8B, 0x27, 0x8F, 0xC3, 0x0A, 0x54},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_message;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xFC, 0xAC, 0x2F, 0x6D, 0x47, 0x05, 0x4F, 0x6B, 0xB2, 0x28, 0x64, 0x2A, 0xCC, 0xAC, 0x72, 0x38},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_end;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xF7, 0x99, 0xFF, 0x69, 0x19, 0x92, 0x43, 0xC4, 0xA3, 0xD8, 0xD8, 0x31, 0xF4, 0x75, 0xDC, 0x75},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_return;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xE1, 0x5B, 0xA7, 0x02, 0x28, 0xC4, 0x42, 0x92, 0xAB, 0x8F, 0xFF, 0xA4, 0x34, 0x70, 0x91, 0x28},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_refusal;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xC0, 0xBB, 0x14, 0xC7, 0x60, 0x22, 0x49, 0xDA, 0xAD, 0x08, 0x79, 0x2D, 0x67, 0xE8, 0xB4, 0x70},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_constrain;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xFD, 0x3D, 0xDA, 0x11, 0xC8, 0xAB, 0x40, 0x33, 0x87, 0x6E, 0xD9, 0x3D, 0xEB, 0x17, 0x2C, 0x93},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_channel;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0x12, 0x20, 0xF7, 0x96, 0xE3, 0x88, 0x4D, 0xE5, 0xB4, 0x87, 0xFE, 0x2E, 0xB5, 0xFE, 0x03, 0xC0},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_call;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0x07, 0xD7, 0xDA, 0x55, 0xB3, 0x46, 0x4C, 0xFF, 0x8B, 0x37, 0x7C, 0xEF, 0xAC, 0xF8, 0xA3, 0xE8},\n        uuid,\n        
sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_untrusted;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0xF2, 0x65, 0xBD, 0x9C, 0xC7, 0x17, 0x46, 0x9E, 0xA4, 0x47, 0x92, 0x06, 0x87, 0xD6, 0x5D, 0x90},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        return gptoss_special_token_end_untrusted;\n    } else if (memcmp(\n        &(struct gptoss_uuid) {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},\n        uuid,\n        sizeof(struct gptoss_uuid)) == 0)\n    {\n        // Suppress warning\n        return gptoss_special_token_invalid;\n    } else {\n        GPTOSS_LOG_WARNING(\"unsupported special token \" UUID_FORMAT, UUID_ARGS(*uuid));\n        return gptoss_special_token_invalid;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/log.c",
    "content": "#include <assert.h>  // assert\n#include <stdarg.h>  // va_list, va_copy, va_end\n#include <stdio.h>  // vsnprintf\n#include <stdlib.h>  // malloc, free\n\n#include <unistd.h>  // STDERR_FILENO, ssize_t, write\n\n\n\n#define GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE 16384\n\nvoid gptoss_format_log(const char* format, va_list args) {\n    char stack_buffer[GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE];\n    char* heap_buffer = NULL;\n\n    va_list args_copy;\n    va_copy(args_copy, args);\n\n    const int vsnprintf_result = vsnprintf(stack_buffer, GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE, format, args);\n    assert(vsnprintf_result >= 0);\n\n    // At least a partially formatted message is ready in the on-stack buffer.\n    char* message_buffer = &stack_buffer[0];\n    size_t message_size = (size_t) vsnprintf_result;\n    // vsnprintf returns the message length excluding the terminating NUL, so the\n    // on-stack buffer holds the full message only if message_size < buffer size.\n    if (message_size >= GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE) {\n        heap_buffer = malloc(message_size + 1);\n        if (heap_buffer == NULL) {\n            // Fall back to the truncated message in the on-stack buffer (excluding its NUL terminator).\n            message_size = GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE - 1;\n        } else {\n            // Format the full message plus its NUL terminator into the heap buffer.\n            vsnprintf(heap_buffer, message_size + 1, format, args_copy);\n            message_buffer = heap_buffer;\n        }\n    }\n\n    ssize_t bytes_written;\n    do {\n        bytes_written = write(STDERR_FILENO, message_buffer, message_size);\n        if (bytes_written > 0) {\n            assert((size_t) bytes_written <= message_size);\n            message_buffer += bytes_written;\n            message_size -= bytes_written;\n        }\n    } while (bytes_written >= 0 && message_size != 0);\n\n    free(heap_buffer);\n    va_end(args_copy);\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/matmul.metal",
    "content": "#include <metal_atomic>\n#include <metal_compute>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\n// Each simdgroup reduces all channels of the input and computes a single channel of the output\n// + Efficient synchronization\n// + Sequential memory access within a warp\n// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels\n// + Reuse input vector from threadgroup memory\n// + Avoid synchronization across warps when doing reduction\n\nkernel void gptoss_f32_bf16w_matmul(\n    constant gptoss_matmul_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device bfloat4* weight [[ buffer(2) ]],\n    const device bfloat* bias [[ buffer(3) ]],\n    device float* output [[ buffer(4) ]],\n    const device gptoss_control* control [[ buffer(5) ]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_column_vecs = args.num_column_vecs;\n    const uint row = gid.x * num_simdgroups + simdgroup_idx;\n\n    input += gid.y * num_column_vecs + simdgroup_tid;\n    weight += num_column_vecs * row + simdgroup_tid;\n    bias += row;\n    output += gid.y * args.num_rows + row;\n\n    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;\n\n    float4 sum4 = 0.0f;\n    do {\n        const bfloat4 w = *weight;\n        const float4 i = *input;\n        sum4 = metal::fma(static_cast<float4>(w), i, sum4);\n\n        weight += simdgroup_size;\n        input += simdgroup_size;\n    } while (--num_iter != 0);\n    const float2 sum2 = 
sum4.xy + sum4.zw;\n    float sum = sum2.x + sum2.y;\n    sum = metal::simd_sum(sum);\n    if (metal::simd_is_first()) {\n        sum += static_cast<float>(*bias);\n        if (args.add) {\n            *output += sum;\n        } else {\n            *output = sum;\n        }\n    }\n}\n\nkernel void gptoss_f32_bf16w_matmul_qkv(\n    constant gptoss_qkv_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device bfloat4* weight [[ buffer(2) ]],\n    const device bfloat* bias [[ buffer(3) ]],\n    device float* q [[ buffer(4) ]],\n    device float* kv [[ buffer(5) ]],\n    const device gptoss_control* control [[ buffer(6) ]],\n    threadgroup void* scratch [[ threadgroup(0) ]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    const uint head_dim = 64;\n    const uint num_q_heads = 64;\n    const uint num_kv_heads = 8;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_column_vecs = args.num_column_vecs;\n    const uint row = gid.x * num_simdgroups + simdgroup_idx;\n\n    input += gid.y * num_column_vecs + simdgroup_tid;\n    weight += num_column_vecs * row + simdgroup_tid;\n    bias += row;\n    q += gid.y * args.num_rows;\n\n    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;\n\n    float4 sum4 = 0.0f;\n    do {\n        const bfloat4 w = *weight;\n        const float4 i = *input;\n        sum4 = metal::fma(static_cast<float4>(w), i, sum4);\n\n        weight += simdgroup_size;\n        input += simdgroup_size;\n    } while (--num_iter != 0);\n    const float2 sum2 = sum4.xy + sum4.zw;\n    float sum = sum2.x + sum2.y;\n    sum = metal::simd_sum(sum);\n    if (metal::simd_is_first()) {\n        sum += static_cast<float>(*bias);\n        
static_cast<threadgroup float*>(scratch)[simdgroup_idx] = sum;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    if (simdgroup_idx == 0) {\n        const uint num_half_simdgroups = num_simdgroups / 2;\n        if (simdgroup_tid < num_half_simdgroups) {\n            float2 vals = static_cast<const threadgroup float2*>(scratch)[simdgroup_tid];\n            const uint idx = gid.x * num_half_simdgroups + simdgroup_tid;\n            const uint head_idx = idx / (head_dim / 2);\n            const uint token_idx = args.token_offset + gid.y;\n            const uint dim_idx = idx % (head_dim / 2);\n            if (head_idx < num_q_heads + num_kv_heads) {\n                const float dim_idx_val = static_cast<float>(dim_idx);\n                const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale);\n                const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;\n                const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));\n                const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);\n\n                const float phi = static_cast<float>(token_idx) * inv_freq;\n                const float yarn_multiplier = args.yarn_multiplier;\n                float cosphi;\n                const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;\n                cosphi *= yarn_multiplier;\n\n                vals = (float2) {\n                    vals.x * cosphi - vals.y * sinphi,\n                    vals.x * sinphi + vals.y * cosphi,\n                };\n            }\n            if (head_idx < num_q_heads) {\n                reinterpret_cast<device float2*>(q)[idx] = vals;\n            } else if (head_idx < num_q_heads + num_kv_heads) {\n                const uint h = head_idx - num_q_heads;\n                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + 
token_idx) * 2 * head_dim)[dim_idx] = vals;\n            } else {\n                const uint h = head_idx - num_q_heads - num_kv_heads;\n                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;\n            }\n        }\n    }\n}\n\nkernel void gptoss_f32_bf16w_unembedding(\n    constant gptoss_unembedding_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device bfloat4* weight [[ buffer(2) ]],\n    device float* output [[ buffer(3) ]],\n    device metal::atomic_ulong* argmax [[ buffer(4) ]],\n    const device gptoss_control* control [[ buffer(5) ]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    threadgroup uint2 threadgroup_buffer[32];\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_column_vecs = args.num_column_vecs;\n    const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;\n    const uint row_end = metal::min(gid.x * args.num_rows_per_threadgroup + args.num_rows_per_threadgroup, args.num_rows);\n    const uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;\n\n    input += gid.y * num_column_vecs + simdgroup_tid;\n    weight += num_column_vecs * row_start + simdgroup_tid;\n    output += gid.y * args.num_rows + row_start;\n\n    uint2 row_sum{0xFFFFFFFFul, 0xFFFFFFFFul};\n    for (uint row = row_start; row < row_end; row += num_simdgroups) {\n        uint n = num_iter;\n\n        float4 sum4 = 0.0f;\n        do {\n            const bfloat4 w = *weight;\n            const float4 i = *input;\n\n            sum4 = metal::fma(static_cast<float4>(w), i, sum4);\n\n            weight += simdgroup_size;\n            input += simdgroup_size;\n        } 
while (--n != 0);\n        input -= num_iter * simdgroup_size;\n        weight -= num_iter * simdgroup_size;\n\n        const float2 sum2 = sum4.xy + sum4.zw;\n        float sum = sum2.x + sum2.y;\n        sum = metal::simd_sum(sum);\n        uint sum_bits = as_type<uint>(sum);\n        if (static_cast<int>(sum_bits) >= 0) {\n            sum_bits ^= 0x7FFFFFFFu;\n        }\n        row_sum = as_type<uint2>(metal::min(as_type<ulong>(row_sum), as_type<ulong>(uint2{row, sum_bits})));\n        if (metal::simd_is_first()) {\n            *output = sum;\n        }\n\n        weight += num_column_vecs * num_simdgroups;\n        output += num_simdgroups;\n    }\n    if (metal::simd_is_first()) {\n        threadgroup_buffer[simdgroup_idx] = row_sum;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    if (simdgroup_idx == 0) {\n        // Min-Reduce threadgroup_buffer\n        if (simdgroup_tid < num_simdgroups) {\n            row_sum = threadgroup_buffer[simdgroup_tid];\n        }\n        const uint sum_bits = row_sum.y;\n        const uint sum_bits_min = metal::simd_min(sum_bits);\n        const uint row_min = metal::simd_min(sum_bits == sum_bits_min ? 
row_sum.x : 0xFFFFFFFFu);\n        if (metal::simd_is_first()) {\n            const uint2 threadgroup_output{row_min, sum_bits_min};\n            atomic_min_explicit(&argmax[gid.y], as_type<ulong>(threadgroup_output), metal::memory_order_relaxed);\n        }\n    }\n}\n\n// Current constraints for the dense matmul kernel:\n//  1- All B* and Sg_* are multiples of 8.\n//  2- Bm is divisible by Sg_Bm and Bn is divisible by Sg_Bn.\n//  3- M, N and K are all divisible by 8.\ntemplate <uint Bm, uint Bn, uint Bk, uint Sg_Bm, uint Sg_Bn, uint add = 0>\ninline void _gptoss_f32_bf16w_dense_matmul_impl(\n    constant gptoss_dense_matmul_args& args, const device float* lhs,\n    const device bfloat* rhs, const device bfloat* __restrict__ bias,\n    device float* out, const device gptoss_control* control, threadgroup float* scratch, threadgroup float* bias_tile,\n    uint sg_id, uint sg_count_per_tg, uint3 gid, uint3 tg_id, uint3 local_tid,\n    uint3 threadgroup_size) {\n\n    if (control->abort != 0) {\n        return;\n    }\n\n    // The kernel assumes that M, K, and N are divisible by 8.\n    const uint M = args.m;\n    const uint K = args.k;\n    const uint N = args.n;\n    static_assert((Bm % 8u) == 0u, \"Bm must be a multiple of 8\");\n    static_assert((Bn % 8u) == 0u, \"Bn must be a multiple of 8\");\n    static_assert((Bk % 8u) == 0u, \"Bk must be a multiple of 8\");\n    static_assert((Sg_Bm % 8u) == 0u, \"Sg_Bm must be a multiple of 8\");\n    static_assert((Sg_Bn % 8u) == 0u, \"Sg_Bn must be a multiple of 8\");\n    static_assert((Bn % Sg_Bn) == 0u, \"Bn must be a multiple of Sg_Bn\");\n    static_assert((Bm % Sg_Bm) == 0u, \"Bm must be a multiple of Sg_Bm\");\n\n    // Get row and col tg.\n    const uint row_tg = tg_id.y;\n    const uint col_tg = tg_id.x;\n    // Get the row and col offsets of this threadgroup's output tile.\n    const uint row_tg_offset = row_tg * Bm;\n    const uint col_tg_offset = col_tg * Bn;\n\n    const uint sg_col_count = Bn / Sg_Bn;\n    const uint row_sg = sg_id / 
sg_col_count;\n    const uint col_sg = sg_id % sg_col_count;\n\n    const uint row_sg_offset = row_sg * Sg_Bm;\n    const uint col_sg_offset = col_sg * Sg_Bn;\n    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);\n    // Create an array of simdgroup_float8x8 to hold temp results.\n    metal::simdgroup_float8x8 OutTiles[temp_result_size];\n#pragma clang loop unroll(full)\n    for (uint i = 0; i < temp_result_size; i++) {\n        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(\n            static_cast<float>(0.0));\n    }\n\n    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {\n#pragma clang loop unroll(full)\n        for (uint k = 0; k < Bk; k += 8) {\n#pragma clang loop unroll(full)\n            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n                // const uint m_subtile = row_sg_offset + m_subtile_;\n                // const uint row_index_in_out_tile = (m_subtile - row_sg_offset) / 8;\n                const uint row_index_in_out_tile = m_subtile_ / 8;\n                metal::simdgroup_float8x8 LHStile;\n                const uint k_id = k + k_offset;\n                const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_;\n                metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset));\n                metal::simdgroup_bfloat8x8 RHStile;\n#pragma clang loop unroll(full)\n                for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n                    const uint col_index_in_out_tile = n_subtile_ / 8;\n                    const uint current_index_out_tile =\n                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n                    const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_;\n                    simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true);\n                    // If rhs was not transposed, use the following instead:\n                    // simdgroup_load(RHStile, 
rhs, N, ulong2(col_offset, k_id));\n                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile],\n                                                  LHStile, RHStile,\n                                                  OutTiles[current_index_out_tile]);\n                }\n            }\n        }\n    }\n    // Epilogue.\n#pragma clang loop unroll(full)\n    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n        const uint col_index_in_out_tile = n_subtile_ / 8;\n        const uint local_col_offset = col_sg_offset + n_subtile_;\n#pragma clang loop unroll(full)\n        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n            const uint row_index_in_out_tile = m_subtile_ / 8;\n            const uint local_row_offset = row_sg_offset + m_subtile_;\n            const uint current_index_out_tile =\n                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,\n                            ulong2(local_col_offset, local_row_offset));\n        }\n    }\n    // TODO(ibahmed): vectorize these loads and maybe unroll the loop.\n    const uint thread_count_per_tg =\n        threadgroup_size.x * threadgroup_size.y * threadgroup_size.z;\n    for (uint c_local = local_tid.x; c_local < Bn;\n         c_local += thread_count_per_tg) {\n        const uint c_global = col_tg_offset + c_local;\n        bias_tile[c_local] =\n            (c_global < N) ? 
static_cast<float>(bias[c_global]) : 0.0f;\n    }\n\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n\n    // TODO(ibahmed): vectorize these stores and maybe unroll the loop.\n    for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {\n        const uint r = idx / Bn;\n        const uint c = idx % Bn;\n\n        const uint out_row = row_tg_offset + r;\n        const uint out_col = col_tg_offset + c;\n\n        if (out_row < M && out_col < N) {\n            float acc = scratch[idx] + bias_tile[c];\n            if (add) {\n                acc += out[out_row * N + out_col];\n            }\n            out[out_row * N + out_col] = acc;\n        }\n    }\n}\n\nkernel void gptoss_f32_bf16w_dense_matmul_qkv(\n    constant gptoss_dense_matmul_qkv_args& args [[buffer(0)]],\n    const device float* lhs [[buffer(1)]],\n    const device bfloat* rhs [[buffer(2)]],\n    const device bfloat* __restrict__ bias [[buffer(3)]],\n    device float* out [[buffer(4)]],\n    device float* kv [[buffer(5)]],\n    const device gptoss_control* control [[buffer(6)]],\n    uint sg_id [[simdgroup_index_in_threadgroup]],\n    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 tg_id [[threadgroup_position_in_grid]],\n    uint3 local_tid [[thread_position_in_threadgroup]],\n    uint3 threadgroup_size [[threads_per_threadgroup]]) {\n    threadgroup float scratch[QKV_Bm * QKV_Bn];\n    threadgroup float bias_tile[QKV_Bn];\n    if (control->abort != 0) {\n        return;\n    }\n\n    // The kernel assumes that QKV_Bm, QKV_Bn, QKV_Bk, QKV_Sg_Bm, QKV_Sg_Bn are divisible by 8.\n    const uint M = args.m;\n    const uint K = args.k;\n    const uint N = args.n;\n    const uint Bm = QKV_Bm;\n    const uint Bn = QKV_Bn;\n    const uint Bk = QKV_Bk;\n    const uint Sg_Bm = QKV_Sg_Bm;\n    const uint Sg_Bn = QKV_Sg_Bn;\n    static_assert((Bm % 8u) == 0u, \"Bm must be a multiple of 8\");\n    
static_assert((Bn % 8u) == 0u, \"Bn must be a multiple of 8\");\n    static_assert((Bk % 8u) == 0u, \"Bk must be a multiple of 8\");\n    static_assert((Sg_Bm % 8u) == 0u, \"Sg_Bm must be a multiple of 8\");\n    static_assert((Sg_Bn % 8u) == 0u, \"Sg_Bn must be a multiple of 8\");\n    static_assert((Bn % Sg_Bn) == 0u, \"Bn must be a multiple of Sg_Bn\");\n    static_assert((Bm % Sg_Bm) == 0u, \"Bm must be a multiple of Sg_Bm\");\n\n    // Get row and col tg.\n    const uint row_tg = tg_id.y;\n    const uint col_tg = tg_id.x;\n    // Get the row and col offsets of this threadgroup's output tile.\n    const uint row_tg_offset = row_tg * Bm;\n    const uint col_tg_offset = col_tg * Bn;\n\n    const uint sg_col_count = Bn / Sg_Bn;\n    const uint row_sg = sg_id / sg_col_count;\n    const uint col_sg = sg_id % sg_col_count;\n\n    const uint row_sg_offset = row_sg * Sg_Bm;\n    const uint col_sg_offset = col_sg * Sg_Bn;\n    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);\n    // Create an array of simdgroup_float8x8 to hold temp results.\n    metal::simdgroup_float8x8 OutTiles[temp_result_size];\n#pragma clang loop unroll(full)\n    for (uint i = 0; i < temp_result_size; i++) {\n        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(\n            static_cast<float>(0.0));\n    }\n\n    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {\n#pragma clang loop unroll(full)\n        for (uint k = 0; k < Bk; k += 8) {\n#pragma clang loop unroll(full)\n            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n                const uint row_index_in_out_tile = m_subtile_ / 8;\n                metal::simdgroup_float8x8 LHStile;\n                const uint k_id = k + k_offset;\n                const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_;\n                metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset));\n                metal::simdgroup_bfloat8x8 RHStile;\n#pragma clang loop unroll(full)\n                for (uint n_subtile_ = 
0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n                    const uint col_index_in_out_tile = n_subtile_ / 8;\n                    const uint current_index_out_tile =\n                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n                    const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_;\n                    simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true);\n                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile],\n                                                  LHStile, RHStile,\n                                                  OutTiles[current_index_out_tile]);\n                }\n            }\n        }\n    }\n    // Epilogue.\n#pragma clang loop unroll(full)\n    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n        const uint col_index_in_out_tile = n_subtile_ / 8;\n        const uint local_col_offset = col_sg_offset + n_subtile_;\n#pragma clang loop unroll(full)\n        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n            const uint row_index_in_out_tile = m_subtile_ / 8;\n            const uint local_row_offset = row_sg_offset + m_subtile_;\n            const uint current_index_out_tile =\n                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,\n                            ulong2(local_col_offset, local_row_offset));\n        }\n    }\n    // TODO(ibahmed): vectorize these loads and maybe unroll the loop.\n    const uint thread_count_per_tg =\n        threadgroup_size.x * threadgroup_size.y * threadgroup_size.z;\n    for (uint c_local = local_tid.x; c_local < Bn;\n         c_local += thread_count_per_tg) {\n        const uint c_global = col_tg_offset + c_local;\n        bias_tile[c_local] =\n            (c_global < N) ? 
static_cast<float>(bias[c_global]) : 0.0f;\n    }\n\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    const uint q_heads = 64;\n    const uint kv_heads = 8;\n    const uint head_dim = 64;\n    const uint q_cols = q_heads * head_dim;\n    const uint k_cols = kv_heads * head_dim;\n\n    // TODO(ibahmed): vectorize these stores and maybe unroll the loop.\n    for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {\n        const uint r = idx / Bn;\n        const uint c = idx % Bn;\n\n        const uint out_row = row_tg_offset + r;\n        const uint out_col = col_tg_offset + c;\n\n        if (out_row < M && out_col < N) {\n            float acc = scratch[idx] + bias_tile[c];\n            if ((out_col < q_cols + k_cols)) {\n                out[out_row * N + out_col] = acc;\n            } else {\n                // Write v into kv cache.\n                const uint v_col = out_col - q_cols - k_cols;\n                const uint v_head = v_col / head_dim;\n                const uint dim_idx = v_col % head_dim;\n                const uint token_idx = args.token_offset + out_row;\n                kv[(v_head * args.max_tokens + token_idx) * 2 * head_dim + head_dim + dim_idx] = acc;\n            }\n        }\n    }\n}\n\nkernel void gptoss_f32_bf16w_dense_matmul_attn_output(\n    constant gptoss_dense_matmul_args& args [[buffer(0)]],\n    const device float* lhs [[buffer(1)]],\n    const device bfloat* rhs [[buffer(2)]],\n    const device bfloat* __restrict__ bias [[buffer(3)]],\n    device float* out [[buffer(4)]],\n    const device gptoss_control* control [[buffer(5)]],\n    uint sg_id [[simdgroup_index_in_threadgroup]],\n    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 tg_id [[threadgroup_position_in_grid]],\n    uint3 local_tid [[thread_position_in_threadgroup]],\n    uint3 threadgroup_size [[threads_per_threadgroup]]) {\n    threadgroup float 
scratch[ATTN_OUTPUT_Bm * ATTN_OUTPUT_Bn];\n    threadgroup float bias_tile[ATTN_OUTPUT_Bn];\n    _gptoss_f32_bf16w_dense_matmul_impl<ATTN_OUTPUT_Bm, ATTN_OUTPUT_Bn,\n                                        ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm,\n                                        ATTN_OUTPUT_Sg_Bn, /*add=*/1>(\n        args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,\n        gid, tg_id, local_tid, threadgroup_size);\n}\n\nkernel void gptoss_f32_bf16w_dense_matmul_mlp_gate(\n    constant gptoss_dense_matmul_args& args [[buffer(0)]],\n    const device float* lhs [[buffer(1)]],\n    const device bfloat* rhs [[buffer(2)]],\n    const device bfloat* __restrict__ bias [[buffer(3)]],\n    device float* out [[buffer(4)]],\n    const device gptoss_control* control [[buffer(5)]],\n    uint sg_id [[simdgroup_index_in_threadgroup]],\n    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 tg_id [[threadgroup_position_in_grid]],\n    uint3 local_tid [[thread_position_in_threadgroup]],\n    uint3 threadgroup_size [[threads_per_threadgroup]]) {\n    threadgroup float scratch[MLP_GATE_Bm * MLP_GATE_Bn];\n    threadgroup float bias_tile[MLP_GATE_Bn];\n    _gptoss_f32_bf16w_dense_matmul_impl<MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk,\n                                        MLP_GATE_Sg_Bm, MLP_GATE_Sg_Bn>(\n        args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,\n        gid, tg_id, local_tid, threadgroup_size);\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/metal-kernels.c",
    "content": "#include <inttypes.h>\n#include <stddef.h>\n#include <stdint.h>\n#include <math.h>\n\n#include <internal/kernel-args.h>\n#include <internal/log.h>\n#include <internal/math.h>\n#include <internal/metal.h>\n#include <internal/metal-kernels.h>\n\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* u32_fill_random_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset)\n{\n    if (command_buffer->object == NULL || u32_fill_random_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = u32_fill_random_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > u32_fill_random_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_elements;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_u32_fill_random_args args = {\n        .num_vecs = num_vecs,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n        .seed = rng_seed,\n        .offset = rng_offset,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, u32_fill_random_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, 1, 1,\n        sizeof(args), &args,\n        1, &output_buffer, &output_offset,\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n    const struct 
gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_fill_random_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset,\n    float rng_min,\n    float rng_max)\n{\n    if (command_buffer->object == NULL || f32_fill_random_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_fill_random_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > f32_fill_random_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (rng_min >= rng_max) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_elements;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_f32_fill_random_args args = {\n        .num_vecs = num_vecs,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n        .seed = rng_seed,\n        .offset = rng_offset,\n        .scale = (rng_max - rng_min) * 0x1.0p-32f,\n        .bias = (rng_min + rng_max) * 0.5f,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_fill_random_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, 1, 1,\n        sizeof(args), &args,\n        1, &output_buffer, &output_offset,\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* bf16_fill_random_fn,\n    size_t threadgroup_size,\n    size_t 
max_threadgroups,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint64_t num_elements,\n    uint64_t rng_seed,\n    uint64_t rng_offset,\n    float rng_min,\n    float rng_max)\n{\n    if (command_buffer->object == NULL || bf16_fill_random_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = bf16_fill_random_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > bf16_fill_random_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (rng_min >= rng_max) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_elements;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_f32_fill_random_args args = {\n        .num_vecs = num_vecs,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n        .seed = rng_seed,\n        .offset = rng_offset,\n        .scale = (rng_max - rng_min) * 0x1.0p-32f,\n        .bias = (rng_min + rng_max) * 0.5f,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, bf16_fill_random_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, 1, 1,\n        sizeof(args), &args,\n        1, &output_buffer, &output_offset,\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* mf4_f32_convert_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* block_buffer,\n    const struct gptoss_metal_buffer* scale_buffer,\n    const struct 
gptoss_metal_buffer* output_buffer,\n    uint64_t num_elements)\n{\n    if (command_buffer->object == NULL || mf4_f32_convert_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_elements % 32 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = mf4_f32_convert_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > mf4_f32_convert_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_elements / 32;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_convert_args args = {\n        .num_vecs = num_vecs,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, mf4_f32_convert_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, 1, 1,\n        sizeof(args), &args,\n        3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL,\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* bf16_f32_embeddings_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* token_buffer,\n    size_t token_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_channels)\n{\n    if 
(command_buffer->object == NULL || bf16_f32_embeddings_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_channels % 4 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = bf16_f32_embeddings_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > bf16_f32_embeddings_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const uint32_t num_vecs = num_channels / 4;\n    const struct gptoss_embeddings_args args = {\n        .num_vecs = num_vecs,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, bf16_f32_embeddings_fn,\n        threadgroup_size, 1, 1,\n        num_tokens, 1, 1,\n        sizeof(args), &args,\n        4,\n        (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer},\n        (const size_t[]) {token_offset, weight_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_channels,\n    float epsilon)\n{\n    if (command_buffer->object == NULL || f32_bf16w_rmsnorm_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_channels % 4 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (f32_bf16w_rmsnorm_fn->max_threadgroup_threads < 
1024) {\n        return gptoss_status_unsupported_system;\n    }\n\n    if (f32_bf16w_rmsnorm_fn->simdgroup_threads != 32) {\n        return gptoss_status_unsupported_system;\n    }\n\n    const uint32_t num_vecs = num_channels / 4;\n    const struct gptoss_rmsnorm_args args = {\n        .num_vecs = num_vecs,\n        .num_channels = (float) num_channels,\n        .epsilon = epsilon,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_rmsnorm_fn,\n        /*threadgroup_size=*/1024, 1, 1,\n        num_tokens, 1, 1,\n        sizeof(args), &args,\n        4,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, weight_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;\n    } else if (threadgroup_size > 
f32_bf16w_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 4 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul kernel launch: number of columns (%\" PRIu32 \") is not divisible by 4\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;\n    if (num_rows % num_simdgroups != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul kernel launch: number of rows (%\" PRIu32 \") is not divisible by the number of simdgroups (%zu)\",\n            num_rows, num_simdgroups);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_matmul_args args = {\n        .num_column_vecs = num_cols / 4,\n        .num_rows = num_rows,\n        .add = 0,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_matmul_fn,\n        threadgroup_size, 1, 1,\n        num_rows / num_simdgroups, num_tokens, 1,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t 
weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t attn_head_dim,\n    uint32_t token_offset,\n    uint32_t max_tokens,\n    float rope_base,\n    float interpolation_scale,\n    float yarn_offset,\n    float yarn_scale,\n    float yarn_multiplier)\n{\n    if (command_buffer->object == NULL || f32_bf16w_matmul_qkv_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_bf16w_matmul_qkv_fn->simdgroup_threads;\n    } else if (threadgroup_size > f32_bf16w_matmul_qkv_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_bf16w_matmul_qkv_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 4 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: number of columns (%\" PRIu32 \") is not divisible by 4\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_q_heads != 64) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: number of Q heads (%\" PRIu32 \") must be 64\",\n            num_q_heads);\n        return gptoss_status_invalid_argument;\n    }\n    if (num_kv_heads != 8) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv 
kernel launch: number of KV heads (%\" PRIu32 \") must be 8\",\n            num_kv_heads);\n        return gptoss_status_invalid_argument;\n    }\n    if (attn_head_dim != 64) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: attention head dimension (%\" PRIu32 \") must be 64\",\n            attn_head_dim);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_qkv_fn->simdgroup_threads;\n    const uint32_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;\n    if (num_rows % num_simdgroups != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_qkv kernel launch: number of rows (%\" PRIu32 \") is not divisible by the number of simdgroups (%zu)\",\n            num_rows, num_simdgroups);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_qkv_args args = {\n        .num_column_vecs = num_cols / 4,\n        .num_rows = num_rows,\n        .token_offset = token_offset,\n        .freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),\n        .interpolation_scale = interpolation_scale,\n        .yarn_offset = yarn_offset,\n        .yarn_scale = yarn_scale,\n        .yarn_multiplier = yarn_multiplier,\n        .max_tokens = max_tokens,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_matmul_qkv_fn,\n        threadgroup_size, 1, 1,\n        num_rows / num_simdgroups, num_tokens, 1,\n        sizeof(args), &args,\n        6,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},\n        (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},\n        /*threadgroup_buffer_size=*/num_simdgroups * sizeof(float));\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(\n    const struct 
gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_add kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;\n    } else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_add kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 4 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_add kernel launch: number of columns (%\" PRIu32 \") is not divisible by 4\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;\n    if (num_rows % num_simdgroups != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_matmul_add kernel launch: number of rows (%\" PRIu32 \") is not divisible by the number of simdgroups (%zu)\",\n            num_rows, num_simdgroups);\n        return gptoss_status_invalid_argument;\n 
   }\n\n    const struct gptoss_matmul_args args = {\n        .num_column_vecs = num_cols / 4,\n        .num_rows = num_rows,\n        .add = 1,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_matmul_fn,\n        threadgroup_size, 1, 1,\n        num_rows / num_simdgroups, num_tokens, 1,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset, \n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows,\n    uint32_t Bm,\n    uint32_t Bn,\n    uint32_t Bk,\n    uint32_t Sg_Bm,\n    uint32_t Sg_Bn)\n{\n\n    if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_cols % 8 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%\" PRIu32 \") is not divisible by 8\",\n                         num_cols);\n        return 
gptoss_status_invalid_argument;\n    }\n    if (num_rows % 8 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%\" PRIu32 \") is not divisible by 8\",\n                         num_rows);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_dense_matmul_args args = {\n        .m = num_tokens,\n        .n = num_rows,\n        .k = num_cols,\n    };\n    const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads;\n    const uint32_t m = args.m;\n    const uint32_t n = args.n;\n    const uint32_t k = args.k;\n    if (Bm % Sg_Bm != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%\" PRIu32 \") is not divisible by Sg_Bm (%\" PRIu32 \")\",\n                         Bm, Sg_Bm);\n        return gptoss_status_invalid_argument;\n    }\n    if (Bn % Sg_Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%\" PRIu32 \") is not divisible by Sg_Bn (%\" PRIu32 \")\",\n                         Bn, Sg_Bn);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;\n    const size_t threadgroup_size_y = 1;\n    const size_t threadgroup_size_z = 1;\n    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;\n    if (total_threadgroup_size > f32_bf16w_dense_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)\",\n                         total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n    if (n % Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: n (%\" PRIu32 \") is not divisible by Bn (%\" PRIu32 \")\",\n  
                       n, Bn);\n        return gptoss_status_invalid_argument;\n    }\n    if (k % Bk != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: k (%\" PRIu32 \") is not divisible by Bk (%\" PRIu32 \")\",\n                         k, Bk);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t grid_x = n / Bn;\n    const size_t grid_y = math_ceil_div(m, Bm);\n    const size_t grid_z = 1;\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_dense_matmul_fn,\n        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,\n        grid_x, grid_y, grid_z,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},\n        (const size_t[]){input_offset, weight_offset, bias_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows,\n    uint32_t max_tokens,\n    uint32_t token_offset)\n{\n    if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode 
f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_cols % 8 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%\" PRIu32 \") is not divisible by 8\",\n                         num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    if (num_rows % 8 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%\" PRIu32 \") is not divisible by 8\",\n                         num_rows);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_dense_matmul_qkv_args args = {\n        .m = num_tokens,\n        .n = num_rows,\n        .k = num_cols,\n        .max_tokens = max_tokens,\n        .token_offset = token_offset,\n    };\n    const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads;\n    const uint32_t m = args.m;\n    const uint32_t n = args.n;\n    const uint32_t k = args.k;\n    if (QKV_Bm % QKV_Sg_Bm != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%\" PRIu32 \") is not divisible by Sg_Bm (%\" PRIu32 \")\",\n                         QKV_Bm, QKV_Sg_Bm);\n        return gptoss_status_invalid_argument;\n    }\n    if (QKV_Bn % QKV_Sg_Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%\" PRIu32 \") is not divisible by Sg_Bn (%\" PRIu32 \")\",\n                         QKV_Bn, QKV_Sg_Bn);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t threadgroup_size_x = (QKV_Bm / QKV_Sg_Bm) * (QKV_Bn / QKV_Sg_Bn) * threads_per_simdgroup;\n    const size_t threadgroup_size_y = 1;\n    const size_t threadgroup_size_z = 1;\n    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;\n    if (total_threadgroup_size > 
f32_bf16w_dense_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)\",\n                         total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n    if (n % QKV_Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: n (%\" PRIu32 \") is not divisible by Bn (%\" PRIu32 \")\",\n                         n, QKV_Bn);\n        return gptoss_status_invalid_argument;\n    }\n    if (k % QKV_Bk != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_dense_matmul kernel launch: k (%\" PRIu32 \") is not divisible by Bk (%\" PRIu32 \")\",\n                         k, QKV_Bk);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t grid_x = n / QKV_Bn;\n    const size_t grid_y = math_ceil_div(m, QKV_Bm);\n    const size_t grid_z = 1;\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_dense_matmul_fn,\n        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,\n        grid_x, grid_y, grid_z,\n        sizeof(args), &args,\n        6,\n        (const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},\n        (const size_t[]){input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const 
struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(\n        command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,\n        weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,\n        output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, ATTN_OUTPUT_Bm,\n        ATTN_OUTPUT_Bn, ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm, ATTN_OUTPUT_Sg_Bn);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(\n        command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,\n        weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,\n        output_offset, control_buffer, control_offset, num_tokens, num_cols,\n        num_rows, MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk, MLP_GATE_Sg_Bm,\n        MLP_GATE_Sg_Bn);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(\n    const struct 
gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_bf16w_unembedding_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_buffer,\n    size_t weight_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* argmax_buffer,\n    size_t argmax_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_bf16w_unembedding_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_unembedding kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_bf16w_unembedding_fn->simdgroup_threads;\n    } else if (threadgroup_size > f32_bf16w_unembedding_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_unembedding kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_bf16w_unembedding_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 4 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_bf16w_unembedding kernel launch: number of columns (%\" PRIu32 \") is not divisible by 4\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_simdgroups = threadgroup_size / f32_bf16w_unembedding_fn->simdgroup_threads;\n    const size_t num_rows_per_threadgroup = math_ceil_div(num_rows, max_threadgroups * num_simdgroups) * num_simdgroups;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_rows, 
num_rows_per_threadgroup));\n    const struct gptoss_unembedding_args args = {\n        .num_column_vecs = num_cols / 4,\n        .num_rows_per_threadgroup = num_rows_per_threadgroup,\n        .num_rows = num_rows,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_bf16w_unembedding_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, num_tokens, 1,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer},\n        (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    float swiglu_limit,\n    uint32_t expert_stride,\n    uint32_t num_tokens,\n    uint32_t num_active_experts,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_mf4w_moe_matmul_swiglu_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: invalid command buffer or pipeline state object\");\n        return 
gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = 2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;\n    } else if (threadgroup_size > f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    } else if (threadgroup_size % (2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads)) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) is not divisible by twice the simdgroup size (%zu)\",\n            threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 32 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: number of columns (%\" PRIu32 \") is not divisible by 32\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;\n    if ((2 * num_rows) % num_simdgroups != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: \"\n            \"twice the number of rows (%\" PRIu32 \") is not divisible by the number of simdgroups (%zu)\",\n            num_rows, num_simdgroups);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_moe_matmul_swiglu_args args = {\n        .num_column_vecs = num_cols / 32,\n        .num_rows = num_rows,\n        .num_active_experts = num_active_experts,\n        .weight_expert_stride = expert_stride,\n        .output_expert_stride = num_rows * num_tokens,\n        .swiglu_min = -swiglu_limit,\n        .swiglu_max = 
swiglu_limit,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_mf4w_moe_matmul_swiglu_fn,\n        threadgroup_size, 1, 1,\n        (2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,\n        sizeof(args), &args,\n        7,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t expert_stride,\n    uint32_t num_tokens,\n    uint32_t num_active_experts,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_mf4w_moe_matmul_fn->pipeline_state_object == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch: invalid command buffer or pipeline state object\");\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = 
f32_mf4w_moe_matmul_fn->simdgroup_threads;\n    } else if (threadgroup_size > f32_mf4w_moe_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            threadgroup_size, f32_mf4w_moe_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    } else if (threadgroup_size % f32_mf4w_moe_matmul_fn->simdgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu)\",\n            threadgroup_size, f32_mf4w_moe_matmul_fn->simdgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_cols % 32 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch: number of columns (%\" PRIu32 \") is not divisible by 32\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_fn->simdgroup_threads;\n    if (num_rows % num_simdgroups != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_matmul kernel launch: \"\n            \"the number of rows (%\" PRIu32 \") is not divisible by the number of simdgroups (%zu)\",\n            num_rows, num_simdgroups);\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_moe_matmul_args args = {\n        .num_column_vecs = num_cols / 32,\n        .num_rows = num_rows,\n        .num_active_experts = num_active_experts,\n        .input_expert_stride = num_tokens * (num_cols / 32),\n        .weight_expert_stride = expert_stride,\n        .output_expert_stride = num_rows * num_tokens,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_mf4w_moe_matmul_fn,\n        threadgroup_size, 1, 1,\n        num_rows / num_simdgroups, num_tokens, num_active_experts,\n       
 sizeof(args), &args,\n        7,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_rope_fn,\n    size_t threadgroup_size,\n    const struct gptoss_metal_buffer* activations_buffer,\n    size_t activations_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    float rope_base,\n    float interpolation_scale,\n    float yarn_offset,\n    float yarn_scale,\n    float yarn_multiplier,\n    uint32_t num_tokens,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t attn_head_dim,\n    uint32_t max_tokens,\n    uint32_t token_offset)\n{\n    if (command_buffer->object == NULL || f32_rope_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_rope_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > f32_rope_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_simdgroups = threadgroup_size / f32_rope_fn->simdgroup_threads;\n    const uint32_t num_qk_heads = num_q_heads + num_kv_heads;\n    if (num_qk_heads % num_simdgroups != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_rope_args args = {\n        .token_stride = (num_q_heads + 2 * num_kv_heads) * (attn_head_dim / 2),\n        .token_offset = token_offset,\n        .max_tokens = max_tokens,\n        .freq_scale = 
-logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),\n        .interpolation_scale = interpolation_scale,\n        .yarn_offset = yarn_offset,\n        .yarn_scale = yarn_scale,\n        .yarn_multiplier = yarn_multiplier,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_rope_fn,\n        threadgroup_size, 1, 1,\n        num_qk_heads / num_simdgroups, num_tokens, 1,\n        sizeof(args), &args,\n        3,\n        (const struct gptoss_metal_buffer *[]) {activations_buffer, kv_buffer, control_buffer},\n        (const size_t[]) {activations_offset, kv_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* expert_routing_metadata_fn,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    size_t intra_expert_offsets_offset,\n    uint32_t num_tokens,\n    uint32_t num_experts)\n{\n    if (command_buffer->object == NULL || expert_routing_metadata_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n    \n    const struct gptoss_expert_routing_metadata_args args = {\n        .tokens = num_tokens,\n        .num_experts = num_experts,\n    };\n    const uint32_t threadgroup_size = 256;\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, expert_routing_metadata_fn,\n        threadgroup_size, 1, 1,\n        /*num_threadgroups_x=*/1, /*num_threadgroups_y=*/1, /*num_threadgroups_z=*/1,\n        sizeof(args), &args,\n        3,\n        (const struct gptoss_metal_buffer *[]) {expert_predictions_buffer, expert_offsets_buffer, 
intra_expert_offsets_buffer},\n        (const size_t[]) {expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_scatter_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    size_t intra_expert_offsets_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_active_experts)\n{\n    if (command_buffer->object == NULL || f32_scatter_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_channels % 4 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_channels / 4;\n    const size_t tgx = math_min(num_vecs, 64);\n    const size_t tgy = 1;\n    const size_t tgz = 1;\n    const size_t grid_x = math_ceil_div(num_vecs, tgx);\n    const size_t grid_y = num_tokens;\n    const size_t grid_z = 1;\n    const size_t total_threadgroup_size = tgx * tgy * tgz;\n    if (total_threadgroup_size > f32_scatter_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n    const struct gptoss_scatter_args args = {\n        .tokens = num_tokens,\n        .active_experts_per_token = num_active_experts,\n        .token_stride = num_channels,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_scatter_fn,\n        tgx, tgy, tgz,\n        grid_x, grid_y, grid_z,\n        
sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},\n        (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_predictions_buffer,\n    size_t expert_predictions_offset,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* intra_expert_offsets_buffer,\n    size_t intra_expert_offsets_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_active_experts)\n{\n    if (command_buffer->object == NULL || f32_gather_and_accumulate_e4_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_channels % 4 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_channels / 4;\n    const size_t tgx = math_min(num_vecs, 64);\n    const size_t tgy = 1;\n    const size_t tgz = 1;\n    const size_t grid_x = math_ceil_div(num_vecs, tgx);\n    const size_t grid_y = num_tokens;\n    const size_t grid_z = 1;\n    const size_t total_threadgroup_size = tgx * tgy * tgz;\n    if (total_threadgroup_size > f32_gather_and_accumulate_e4_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n    const struct gptoss_gather_args args = {\n        .tokens = num_tokens,\n        
.active_experts_per_token = num_active_experts,\n        .token_stride = num_channels,\n    };\n    \n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_gather_and_accumulate_e4_fn,\n        tgx, tgy, tgz,\n        grid_x, grid_y, grid_z,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},\n        (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    float swiglu_limit,\n    uint32_t expert_stride_bytes,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_swiglu_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_cols % 32 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: number of columns (%\" PRIu32 \") is not divisible by 32\",\n            num_cols);\n        return 
gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_moe_dense_matmul_swiglu_args args = {\n        .n = num_rows,\n        .k = num_cols,\n        .weight_blocks_expert_stride_bytes = expert_stride_bytes,\n        .weight_scales_expert_stride_bytes = expert_stride_bytes,\n        .bias_expert_stride_bytes = expert_stride_bytes,\n        .swiglu_min = -swiglu_limit,\n        .swiglu_max = swiglu_limit,\n    };\n    const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_swiglu_fn->simdgroup_threads;\n    const uint32_t m = num_tokens;\n    const uint32_t n = args.n;\n    const uint32_t k = args.k;\n    const uint32_t Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;\n    const uint32_t Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;\n    const uint32_t Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;\n    const uint32_t Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;\n    const uint32_t Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;\n    if (Bm % Sg_Bm != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bm (%\" PRIu32 \") is not divisible by Sg_Bm (%\" PRIu32 \")\",\n            Bm, Sg_Bm);\n        return gptoss_status_invalid_argument;\n    }\n    if (Bn % Sg_Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bn (%\" PRIu32 \") is not divisible by Sg_Bn (%\" PRIu32 \")\",\n            Bn, Sg_Bn);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;\n    const size_t threadgroup_size_y = 1;\n    const size_t threadgroup_size_z = 1;\n    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;\n    if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            
total_threadgroup_size, f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n    if (n % Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: n (%\" PRIu32 \") is not divisible by Bn (%\" PRIu32 \")\",\n            n, Bn);\n        return gptoss_status_invalid_argument;\n    }\n    if (k % Bk != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: k (%\" PRIu32 \") is not divisible by Bk (%\" PRIu32 \")\",\n            k, Bk);\n        return gptoss_status_invalid_argument;\n    }\n    const size_t grid_x = n / Bn;\n    const size_t grid_y = math_ceil_div(m, Bm);\n    const size_t grid_z = num_experts;\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_mf4w_moe_dense_matmul_swiglu_fn,\n        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,\n        grid_x, grid_y, grid_z,\n        sizeof(args), &args,\n        6,\n        (const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},\n        (const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,\n    const struct gptoss_metal_buffer* expert_offsets_buffer,\n    size_t expert_offsets_offset,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* weight_block_buffer,\n    size_t weight_block_offset,\n    const struct gptoss_metal_buffer* weight_scale_buffer,\n    size_t weight_scale_offset,\n    const struct 
gptoss_metal_buffer* bias_buffer,\n    size_t bias_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    uint32_t expert_stride_bytes,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_cols,\n    uint32_t num_rows)\n{\n    if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_cols % 32 != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: number of columns (%\" PRIu32 \") is not divisible by 32\",\n            num_cols);\n        return gptoss_status_invalid_argument;\n    }\n    const struct gptoss_moe_dense_matmul_args args = {\n        .k = num_cols,\n        .n = num_rows,\n        .weight_blocks_expert_stride_bytes = expert_stride_bytes,\n        .weight_scales_expert_stride_bytes = expert_stride_bytes,\n        .bias_expert_stride_bytes = expert_stride_bytes,\n    };\n\n    const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_fn->simdgroup_threads;\n    const uint32_t m = num_tokens;\n    const uint32_t n = args.n;\n    const uint32_t k = args.k;\n    const uint32_t Bm = MOE_DENSE_MATMUL_Bm;\n    const uint32_t Bn = MOE_DENSE_MATMUL_Bn;\n    const uint32_t Bk = MOE_DENSE_MATMUL_Bk;\n    const uint32_t Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;\n    const uint32_t Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;\n    if (Bm % Sg_Bm != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bm (%\" PRIu32 \") is not divisible by Sg_Bm (%\" PRIu32 \")\",\n            Bm, Sg_Bm);\n        return gptoss_status_invalid_argument;\n    }\n    if (Bn % Sg_Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bn (%\" PRIu32 \") is not divisible by Sg_Bn (%\" PRIu32 \")\",\n            Bn, Sg_Bn);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t threadgroup_size_x = (Bm / 
Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;\n    const size_t threadgroup_size_y = 1;\n    const size_t threadgroup_size_z = 1;\n    const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;\n    if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)\",\n            total_threadgroup_size, f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads);\n        return gptoss_status_invalid_argument;\n    }\n    if (n % Bn != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: n (%\" PRIu32 \") is not divisible by Bn (%\" PRIu32 \")\",\n            n, Bn);\n        return gptoss_status_invalid_argument;\n    }\n    if (k % Bk != 0) {\n        GPTOSS_LOG_ERROR(\"failed to encode f32_mf4w_moe_dense_matmul kernel launch: k (%\" PRIu32 \") is not divisible by Bk (%\" PRIu32 \")\",\n            k, Bk);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t grid_y = math_ceil_div(m, Bm);\n    const size_t grid_x = n / Bn;\n    const size_t grid_z = num_experts;\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_mf4w_moe_dense_matmul_fn,\n        threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,\n        grid_x, grid_y, grid_z,\n        sizeof(args), &args,\n        6,\n        (const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},\n        (const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const 
struct gptoss_metal_function* f32_accumulate_fn,\n    size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* expert_buffer,\n    size_t expert_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    uint32_t num_experts)\n{\n    if (command_buffer->object == NULL || f32_accumulate_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_channels % 4 != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (threadgroup_size == 0) {\n        threadgroup_size = f32_accumulate_fn->max_threadgroup_threads;\n    } else if (threadgroup_size > f32_accumulate_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t num_vecs = num_channels / 4;\n    const size_t num_vecs_per_expert = num_vecs * num_tokens;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_accumulate_args args = {\n        .num_vecs_per_expert = num_vecs_per_expert,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n        .num_vecs = num_vecs,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_accumulate_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, num_tokens, 1,\n        sizeof(args), &args,\n        4,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, expert_offset, output_offset, control_offset},\n        
/*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_topk_fn,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_tokens,\n    uint32_t num_experts,\n    uint32_t num_active_experts)\n{\n    if (command_buffer->object == NULL || f32_topk_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_experts != 32 && num_experts != 128) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_active_experts != 4) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_topk_args args = { 0 };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_topk_fn,\n        /*threadgroup_size=*/32, 1, 1,\n        num_tokens, 1, 1,\n        sizeof(args), &args,\n        3,\n        (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer},\n        (const size_t[]) {input_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_sdpa_fn,\n    const struct gptoss_metal_buffer* q_buffer,\n    size_t q_offset,\n    const struct gptoss_metal_buffer* kv_buffer,\n    size_t kv_offset,\n    const struct gptoss_metal_buffer* s_buffer,\n    size_t s_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t window,\n    uint32_t kv_stride,\n    
uint32_t num_q_tokens,\n    uint32_t num_kv_tokens,\n    uint32_t num_q_heads,\n    uint32_t num_kv_heads,\n    uint32_t head_dim)\n{\n    if (command_buffer->object == NULL || f32_sdpa_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (num_q_heads != num_kv_heads * 8) {\n        GPTOSS_LOG_ERROR(\"number of Q heads (%\" PRIu32 \") must be 8 times the number of KV heads (%\" PRIu32 \")\",\n            num_q_heads, num_kv_heads);\n        return gptoss_status_invalid_argument;\n    }\n\n    if (head_dim != 64) {\n        GPTOSS_LOG_ERROR(\"attention head dimension (%\" PRIu32 \") must be 64\", head_dim);\n        return gptoss_status_invalid_argument;\n    }\n\n    const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window);\n    const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads,\n        max_context_tokens * f32_sdpa_fn->simdgroup_threads);\n    const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads);\n\n    const struct gptoss_sdpa_args args = {\n        .qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),\n        .num_kv_tokens = num_kv_tokens,\n        .kv_stride = kv_stride,\n        .window = window,\n    };\n\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_sdpa_fn,\n        threadgroup_size, 1, 1,\n        num_q_tokens, num_kv_heads, 1,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {q_buffer, kv_buffer, s_buffer, output_buffer, control_buffer},\n        (const size_t[]) {q_offset, kv_offset, s_offset, output_offset, control_offset},\n        /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_softmax_fn,\n   
 size_t threadgroup_size,\n    size_t max_threadgroups,\n    const struct gptoss_metal_buffer* score_buffer,\n    size_t score_offset,\n    const struct gptoss_metal_buffer* argmax_buffer,\n    size_t argmax_offset,\n    const struct gptoss_metal_buffer* prob_buffer,\n    size_t prob_offset,\n    const struct gptoss_metal_buffer* sum_buffer,\n    size_t sum_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint32_t num_channels,\n    uint32_t num_tokens,\n    float temperature,\n    uint32_t* num_threadgroups_out,\n    uint32_t* num_channels_per_threadgroup_out)\n{\n    *num_threadgroups_out = 0;\n    *num_channels_per_threadgroup_out = 0;\n    if (command_buffer->object == NULL || f32_softmax_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    const size_t num_vecs = num_channels;\n    const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;\n    const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));\n    const struct gptoss_softmax_args args = {\n        .num_vecs = num_vecs,\n        .num_vecs_per_threadgroup = num_vecs_per_threadgroup,\n        .max_threadgroups = max_threadgroups,\n        .temperature = temperature,\n    };\n\n    *num_threadgroups_out = num_threadgroups;\n    *num_channels_per_threadgroup_out = num_vecs_per_threadgroup;\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_softmax_fn,\n        threadgroup_size, 1, 1,\n        num_threadgroups, num_tokens, 1,\n        sizeof(args), &args,\n        5,\n        (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer},\n        (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset},\n        /*threadgroup_buffer_size=*/0);\n}\n\nenum gptoss_status 
gptoss_metal_command_buffer_encode_launch_f32_sample(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* f32_sample_fn,\n    size_t min_threadgroup_size,\n    const struct gptoss_metal_buffer* prob_buffer,\n    size_t prob_offset,\n    const struct gptoss_metal_buffer* sum_buffer,\n    size_t sum_offset,\n    const struct gptoss_metal_buffer* token_buffer,\n    size_t token_offset,\n    const struct gptoss_metal_buffer* control_buffer,\n    size_t control_offset,\n    uint64_t rng_seed,\n    uint32_t rng_offset,\n    uint32_t num_blocks,\n    uint32_t num_channels,\n    uint32_t num_channels_per_block)\n{\n    if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) {\n        return gptoss_status_invalid_argument;\n    }\n\n    if (num_blocks > f32_sample_fn->max_threadgroup_threads) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const struct gptoss_sample_args args = {\n        .rng_seed = rng_seed,\n        .rng_offset = rng_offset,\n        .num_blocks = num_blocks,\n        .num_dims = num_channels,\n        .num_dims_per_block = num_channels_per_block,\n    };\n\n    const size_t threadgroup_size = math_max(min_threadgroup_size,\n        math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads));\n    return gptoss_metal_command_buffer_encode_launch_kernel(\n        command_buffer, f32_sample_fn,\n        threadgroup_size, 1, 1,\n        1, 1, 1,\n        sizeof(args), &args,\n        4,\n        (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer},\n        (const size_t[]) {prob_offset, sum_offset, token_offset, control_offset},\n        
/*threadgroup_buffer_size=*/0);\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/metal.m",
    "content": "#import <Foundation/Foundation.h>\n#import <Metal/Metal.h>\n\n#include <dispatch/dispatch.h>\n#include <mach-o/getsect.h>\n\n#include <gpt-oss/types.h>\n\n#include <internal/log.h>\n#include <internal/metal.h>\n\n\nstatic size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {\n    if (!device) {\n        return 0;\n    }\n\n    const uint64_t target_registry_id = [device registryID];\n\n    io_iterator_t it = IO_OBJECT_NULL;\n    const kern_return_t kr = IOServiceGetMatchingServices(\n        kIOMainPortDefault,\n        IOServiceMatching(\"IOAccelerator\"),\n        &it\n    );\n    if (kr != KERN_SUCCESS) {\n        GPTOSS_LOG_ERROR(\"failed to find IOAccelerator objects: error %d\", kr);\n        return 0;\n    }\n\n    size_t result = 0;\n    for (io_object_t obj = IOIteratorNext(it); obj != IO_OBJECT_NULL; obj = IOIteratorNext(it)) {\n        uint64_t registry_id = 0;\n        if (IORegistryEntryGetRegistryEntryID(obj, &registry_id) == KERN_SUCCESS &&\n            registry_id == target_registry_id)\n        {\n            // Read \"gpu-core-count\" from this accelerator node\n            const CFTypeRef value = IORegistryEntryCreateCFProperty(\n                obj, CFSTR(\"gpu-core-count\"), kCFAllocatorDefault, 0);\n            if (value != NULL) {\n                if (CFGetTypeID(value) == CFNumberGetTypeID()) {\n                    int32_t n = -1;\n                    if (CFNumberGetValue((CFNumberRef) value, kCFNumberSInt32Type, &n) && n > 0) {\n                        result = (size_t) n;\n                    }\n                }\n                CFRelease(value);\n            }\n            IOObjectRelease(obj);\n            break;\n        }\n        IOObjectRelease(obj);\n    }\n\n    IOObjectRelease(it);\n    return result;\n}\n\nenum gptoss_status gptoss_metal_device_create_system_default(\n    struct gptoss_metal_device* device_out)\n{\n    id<MTLDevice> device_obj = MTLCreateSystemDefaultDevice();\n    if (device_obj == 
nil) {\n        GPTOSS_LOG_ERROR(\"failed to create Metal device\");\n        return gptoss_status_unsupported_system;\n    }\n\n    device_out->object = (void*) device_obj;\n    device_out->num_cores = gptoss_metal_device_get_core_count(device_obj);\n    device_out->max_buffer_size = (size_t) [device_obj maxBufferLength];\n    device_out->max_threadgroup_memory = (size_t) [device_obj maxThreadgroupMemoryLength];\n    const MTLSize max_threadgroup_threads = [device_obj maxThreadsPerThreadgroup];\n    device_out->max_threadgroup_threads_x = (size_t) max_threadgroup_threads.width;\n    device_out->max_threadgroup_threads_y = (size_t) max_threadgroup_threads.height;\n    device_out->max_threadgroup_threads_z = (size_t) max_threadgroup_threads.depth;\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_device_release(\n    struct gptoss_metal_device* device)\n{\n    if (device->object != NULL) {\n        id<MTLDevice> device_obj = (id<MTLDevice>) device->object;\n        [device_obj release];\n    }\n    memset(device, 0, sizeof(struct gptoss_metal_device));\n    return gptoss_status_success;\n}\n\nextern const struct mach_header_64 __dso_handle;\n\nenum gptoss_status gptoss_metal_library_create_default(\n    const struct gptoss_metal_device* device,\n    struct gptoss_metal_library* library_out)\n{\n    enum gptoss_status status = gptoss_status_success;\n    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;\n    id<MTLLibrary> library_obj = nil;\n    NSAutoreleasePool* autorelease_pool = nil;\n    dispatch_data_t library_blob = NULL;\n\n    unsigned long library_size = 0;\n    uint8_t* library_data = getsectiondata(&__dso_handle, \"__METAL\", \"__shaders\", &library_size);\n    if (library_data != NULL) {\n        library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);\n\n        autorelease_pool = [[NSAutoreleasePool alloc] init];\n        NSError* error_obj = nil;\n        library_obj = 
[device_obj newLibraryWithData:library_blob error:&error_obj];\n        if (library_obj == nil) {\n            GPTOSS_LOG_ERROR(\"failed to create Metal library: %s\", [[error_obj localizedDescription] UTF8String]);\n            status = gptoss_status_unsupported_system;\n            goto cleanup;\n        }\n    } else {\n        // Fall-back to loading from the bundle\n        library_obj = [device_obj newDefaultLibrary];\n        if (library_obj == nil) {\n            GPTOSS_LOG_ERROR(\"failed to create Metal default library\");\n            status = gptoss_status_unsupported_system;\n            goto cleanup;\n        }\n    }\n\n    *library_out = (struct gptoss_metal_library) {\n        .object = (void*) library_obj,\n    };\n\ncleanup:\n    if (library_blob != NULL) {\n        dispatch_release(library_blob);\n    }\n    if (autorelease_pool != nil) {\n        [autorelease_pool drain];\n    }\n    return status;\n}\n\nenum gptoss_status gptoss_metal_library_release(\n    struct gptoss_metal_library* library)\n{\n    if (library->object != NULL) {\n        id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;\n        [library_obj release];\n    }\n    memset(library, 0, sizeof(struct gptoss_metal_library));\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_function_create(\n    const struct gptoss_metal_library* library,\n    const char* name,\n    struct gptoss_metal_function* function_out)\n{\n    __block NSString* error_string_obj = nil;\n    id<MTLFunction> function_obj = nil;\n    MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;\n    __block id<MTLComputePipelineState> pipeline_state_obj = nil;\n    dispatch_semaphore_t pipeline_build_semaphore = NULL;\n    enum gptoss_status status = gptoss_status_success;\n\n    NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];\n    id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;\n    NSString* name_obj = [NSString 
stringWithUTF8String:name];\n    function_obj = [library_obj newFunctionWithName:name_obj];\n    if (function_obj == nil) {\n        GPTOSS_LOG_ERROR(\"failed to create Metal function %s\", name);\n        status = gptoss_status_unsupported_system;\n        goto cleanup;\n    }\n    id<MTLDevice> device_obj = [library_obj device];\n    pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];\n    [pipeline_descriptor_obj setComputeFunction:function_obj];\n    [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];\n\n    pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);\n    [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj\n                                              options:MTLPipelineOptionNone\n                                    completionHandler:^(id<MTLComputePipelineState> _Nullable new_state,\n                                                        MTLComputePipelineReflection* _Nullable reflection,\n                                                        NSError* _Nullable error_obj) {\n        if (new_state != nil) {\n            pipeline_state_obj = [new_state retain];\n        }\n        if (error_obj != nil) {\n            error_string_obj = [[error_obj localizedDescription] copy];\n        }\n        dispatch_semaphore_signal(pipeline_build_semaphore);\n    }];\n    dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);\n\n    if (pipeline_state_obj == nil) {\n        const char* error_string = \"unknown error\";\n        if (error_string_obj != nil) {\n            error_string = [error_string_obj UTF8String];\n        }\n        GPTOSS_LOG_ERROR(\"failed to create Metal compute pipeline state for function %s: %s\",\n            name, error_string);\n        status = gptoss_status_unsupported_system;\n        goto cleanup;\n    }\n\n    // Commit\n    function_out->function_object = function_obj;\n    function_out->pipeline_state_object = 
pipeline_state_obj;\n    function_out->max_threadgroup_threads = (size_t) [pipeline_state_obj maxTotalThreadsPerThreadgroup];\n    function_out->simdgroup_threads = (size_t) [pipeline_state_obj threadExecutionWidth];\n    function_out->static_threadgroup_memory = (size_t) [pipeline_state_obj staticThreadgroupMemoryLength];\n\n    function_obj = nil;\n    pipeline_state_obj = nil;\n\ncleanup:\n    if (function_obj != nil) {\n        [function_obj release];\n    }\n    if (pipeline_descriptor_obj != nil) {\n        [pipeline_descriptor_obj release];\n    }\n    if (error_string_obj != nil) {\n        [error_string_obj release];\n    }\n    if (pipeline_build_semaphore != NULL) {\n        dispatch_release(pipeline_build_semaphore);\n    }\n    if (autorelease_pool != nil) {\n        [autorelease_pool drain];\n    }\n    return status;\n}\n\nenum gptoss_status gptoss_metal_function_release(\n    struct gptoss_metal_function* function)\n{\n    if (function->pipeline_state_object != NULL) {\n        id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;\n        [pipeline_state_obj release];\n    }\n    if (function->function_object != NULL) {\n        id<MTLFunction> function_obj = (id<MTLFunction>) function->function_object;\n        [function_obj release];\n    }\n    memset(function, 0, sizeof(struct gptoss_metal_function));\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_buffer_create(\n    const struct gptoss_metal_device* device,\n    size_t size,\n    const void* data,\n    struct gptoss_metal_buffer* buffer_out)\n{\n    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;\n    id<MTLBuffer> buffer_obj = nil;\n    if (data != NULL) {\n        buffer_obj = [device_obj newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];\n    } else {\n        buffer_obj = [device_obj newBufferWithLength:size options:MTLResourceStorageModeShared];\n    }\n    if (buffer_obj 
== nil) {\n        GPTOSS_LOG_ERROR(\"failed to create Metal buffer of size %zu\", size);\n        return gptoss_status_unsupported_system;\n    }\n    buffer_out->object = (void*) buffer_obj;\n    buffer_out->size = size;\n    buffer_out->ptr = [buffer_obj contents];\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_buffer_wrap(\n    const struct gptoss_metal_device* device,\n    size_t size,\n    const void* data,\n    struct gptoss_metal_buffer* buffer_out)\n{\n    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;\n    id<MTLBuffer> buffer_obj = [device_obj newBufferWithBytesNoCopy:(void*) data length:size options:MTLResourceStorageModeShared deallocator:nil];\n    if (buffer_obj == nil) {\n        GPTOSS_LOG_ERROR(\"failed to wrap Metal buffer of size %zu\", size);\n        return gptoss_status_unsupported_system;\n    }\n    buffer_out->object = (void*) buffer_obj;\n    buffer_out->size = size;\n    buffer_out->ptr = (void*) data;\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_buffer_release(\n    struct gptoss_metal_buffer* buffer)\n{\n    if (buffer->object != NULL) {\n        id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;\n        [buffer_obj release];\n    }\n    memset(buffer, 0, sizeof(struct gptoss_metal_buffer));\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_queue_create(\n    const struct gptoss_metal_device* device,\n    struct gptoss_metal_command_queue* command_queue_out)\n{\n    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;\n    id<MTLCommandQueue> command_queue_obj = [device_obj newCommandQueue];\n    if (command_queue_obj == nil) {\n        GPTOSS_LOG_ERROR(\"failed to create Metal command queue\");\n        return gptoss_status_unsupported_system;\n    }\n    command_queue_out->object = (void*) command_queue_obj;\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_queue_release(\n    struct 
gptoss_metal_command_queue* command_queue)\n{\n    if (command_queue->object != NULL) {\n        id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;\n        [command_queue_obj release];\n    }\n    memset(command_queue, 0, sizeof(struct gptoss_metal_command_queue));\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_create(\n    const struct gptoss_metal_command_queue* command_queue,\n    struct gptoss_metal_command_buffer* command_buffer_out)\n{\n    id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;\n    id<MTLCommandBuffer> command_buffer_obj = [command_queue_obj commandBuffer];\n    if (command_buffer_obj == nil) {\n        GPTOSS_LOG_ERROR(\"failed to create Metal command buffer\");\n        return gptoss_status_unsupported_system;\n    }\n    [command_buffer_obj retain];\n    command_buffer_out->object = (void*) command_buffer_obj;\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_buffer* buffer,\n    size_t offset,\n    size_t size,\n    uint8_t fill_value)\n{\n    if (command_buffer->object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n    if (buffer->object == NULL) {\n        return gptoss_status_invalid_argument;\n    }\n\n    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;\n    id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;\n\n    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];\n\n    const NSRange range = NSMakeRange((NSUInteger) offset, (NSUInteger) size);\n    [command_encoder_obj fillBuffer:buffer_obj range:range value:fill_value];\n    [command_encoder_obj endEncoding];\n\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(\n    
const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_buffer* input_buffer,\n    size_t input_offset,\n    const struct gptoss_metal_buffer* output_buffer,\n    size_t output_offset,\n    size_t size)\n{\n    if (command_buffer->object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n    if (input_buffer->object == NULL) {\n        return gptoss_status_invalid_argument;\n    }\n    if (output_buffer->object == NULL) {\n        return gptoss_status_invalid_argument;\n    }\n\n    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;\n    id<MTLBuffer> input_buffer_obj = (id<MTLBuffer>) input_buffer->object;\n    id<MTLBuffer> output_buffer_obj = (id<MTLBuffer>) output_buffer->object;\n\n    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];\n\n    [command_encoder_obj copyFromBuffer:input_buffer_obj sourceOffset:(NSUInteger) input_offset\n                         toBuffer:output_buffer_obj destinationOffset:(NSUInteger) output_offset\n                         size:(NSUInteger) size];\n    [command_encoder_obj endEncoding];\n\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    const struct gptoss_metal_function* function,\n    size_t threadgroup_size_x,\n    size_t threadgroup_size_y,\n    size_t threadgroup_size_z,\n    size_t num_threadgroups_x,\n    size_t num_threadgroups_y,\n    size_t num_threadgroups_z,\n    size_t params_size,\n    const void* params,\n    size_t num_device_buffers,\n    const struct gptoss_metal_buffer** device_buffers,\n    const size_t* device_buffer_offsets,\n    size_t threadgroup_buffer_size)\n{\n    if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    id<MTLCommandBuffer> command_buffer_obj = 
(id<MTLCommandBuffer>) command_buffer->object;\n    id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;\n\n    id<MTLComputeCommandEncoder> command_encoder_obj = [command_buffer_obj computeCommandEncoder];\n\n    // Set kernel arguments\n    [command_encoder_obj setComputePipelineState:pipeline_state_obj];\n    [command_encoder_obj setBytes:params length:params_size atIndex:0];\n    for (size_t i = 0; i < num_device_buffers; ++i) {\n        id<MTLBuffer> buffer_obj = (id<MTLBuffer>) device_buffers[i]->object;\n        const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i];\n        [command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];\n    }\n    if (threadgroup_buffer_size != 0) {\n        [command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0];\n    }\n\n    // Dispatch kernel\n    const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);\n    const MTLSize num_threadgroups = MTLSizeMake(num_threadgroups_x, num_threadgroups_y, num_threadgroups_z);\n    [command_encoder_obj dispatchThreadgroups:num_threadgroups threadsPerThreadgroup:threadgroup_size];\n    [command_encoder_obj endEncoding];\n\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_commit(\n    const struct gptoss_metal_command_buffer* command_buffer)\n{\n    if (command_buffer->object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;\n    [command_buffer_obj commit];\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_wait_completion(\n    const struct gptoss_metal_command_buffer* command_buffer,\n    double* elapsed_seconds)\n{\n    if (command_buffer->object == NULL) {\n        return gptoss_status_invalid_state;\n    }\n\n    
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;\n    [command_buffer_obj waitUntilCompleted];\n    if (elapsed_seconds != NULL) {\n        const CFTimeInterval start_time = [command_buffer_obj GPUStartTime];\n        const CFTimeInterval end_time = [command_buffer_obj GPUEndTime];\n        *elapsed_seconds = (double) end_time - (double) start_time;\n    }\n    return gptoss_status_success;\n}\n\nenum gptoss_status gptoss_metal_command_buffer_release(\n    struct gptoss_metal_command_buffer* command_buffer)\n{\n    if (command_buffer->object != NULL) {\n        id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;\n        [command_buffer_obj release];\n    }\n    memset(command_buffer, 0, sizeof(struct gptoss_metal_command_buffer));\n    return gptoss_status_success;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/model.c",
    "content": "#include <assert.h>\n#include <inttypes.h>\n#include <limits.h>  // INT_MAX\n#include <stdatomic.h>\n#include <stdint.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <errno.h>  // errno, EISDIR, ENOENT, ENOTDIR\n#include <fcntl.h>  // open\n#include <mach/vm_page_size.h>  // vm_page_size\n#include <sys/mman.h>  // mmap, PROT_READ, MAP_PRIVATE\n#include <sys/stat.h>  // fstat, stat\n#include <sys/types.h>  // off_t, ssize_t\n#include <unistd.h>  // close\n\n#include <gpt-oss.h>\n\n#include \"internal/datatype.h\"\n#include \"internal/kernel-args.h\"  // gptoss_expert_prediction\n#include \"internal/log.h\"\n#include \"internal/uuid.h\"\n#include \"internal/storage.h\"\n#include \"internal/math.h\"\n#include \"internal/model.h\"\n\n\nstatic size_t round_up_to_page_size(size_t bytes) {\n    const size_t page_size_mask = (size_t) vm_page_size - 1;\n    if ((bytes & page_size_mask) != 0) {\n        bytes |= page_size_mask;\n        bytes += 1;\n    }\n    return bytes;\n}\n\nstatic size_t round_down_to_page_size(size_t bytes) {\n    const size_t page_size_mask = (size_t) vm_page_size - 1;\n    return bytes & ~page_size_mask;\n}\n\nstatic enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {\n    assert(fd != -1);\n    assert(data != NULL);\n    assert(size != 0);\n\n    size_t bytes_to_read = size;\n    char* current_byte = (char*) data;\n    do {\n        const ssize_t read_result = read(fd, current_byte, bytes_to_read);\n        if (read_result < 0) {\n            GPTOSS_LOG_ERROR(\"reading %zu bytes from file %s failed with error %d\",\n                size, path, errno);\n            return gptoss_status_io_error;\n        }\n        if (read_result == 0) {\n            // read() returns 0 at end of file; without this check a truncated file would loop forever\n            GPTOSS_LOG_ERROR(\"reading %zu bytes from file %s failed: unexpected end of file\",\n                size, path);\n            return gptoss_status_io_error;\n        }\n        current_byte += (size_t) read_result;\n        bytes_to_read -= (size_t) read_result;\n    } while (bytes_to_read != 0);\n    return gptoss_status_success;\n}\n\nstatic void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {\n    // radvisory.ra_count is int, so we can't prefetch 2GB+ at once\n    
const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);\n    do {\n        const size_t prefetch_size = math_min(size, prefetch_max);\n        const struct radvisory ra = {\n            .ra_offset = offset,\n            .ra_count = (int) prefetch_size,\n        };\n        if (fcntl(fd, F_RDADVISE, &ra) == -1) {\n            GPTOSS_LOG_WARNING(\"fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\\n\",\n                path, (size_t) ra.ra_offset, ra.ra_count, errno);\n            return;\n        }\n        offset += prefetch_size;\n        size -= prefetch_size;\n    } while (size != 0);\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(\n    const char* path,\n    gptoss_model_t* model_out)\n{\n    *model_out = NULL;\n\n    enum gptoss_status status = gptoss_status_success;\n    struct gptoss_model* model = NULL;\n    struct gptoss_tokenizer* tokenizer = NULL;\n    int fd = -1;\n    size_t file_offset = 0;\n\n    fd = open(path, O_RDONLY);\n    if (fd == -1) {\n        GPTOSS_LOG_ERROR(\"open(%s) failed with error %d\", path, errno);\n        switch (errno) {\n            case EISDIR:\n            case ENOENT:\n            case ENOTDIR:\n                status = gptoss_status_invalid_argument;\n                break;\n            default:\n                status = gptoss_status_io_error;\n                break;\n        }\n        goto cleanup;\n    }\n\n    struct gptoss_file_header file_header;\n    status = read_fd(fd, &file_header, sizeof(file_header), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(file_header);\n\n    if (file_header.magic[0] != 'G' ||\n        file_header.magic[1] != 'P' ||\n        file_header.magic[2] != 'T' ||\n        file_header.magic[3] != '-' ||\n        file_header.magic[4] != 'O' ||\n        file_header.magic[5] != 'S' ||\n        file_header.magic[6] != 'S' ||\n        file_header.magic[7] != ' ' ||\n        
file_header.magic[8] != 'v' ||\n        file_header.magic[9] != '1' ||\n        file_header.magic[10] != '.' ||\n        file_header.magic[11] != '0' ||\n        file_header.zero != 0)\n    {\n        GPTOSS_LOG_ERROR(\"invalid magic in file %s\", path);\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n\n    struct gptoss_uuid model_uuid;\n    status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(model_uuid);\n\n    if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {\n        GPTOSS_LOG_ERROR(\"unsupported model UUID \" UUID_FORMAT, UUID_ARGS(model_uuid));\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n\n    struct gptoss_gptoss_model_header model_header;\n    status = read_fd(fd, &model_header, sizeof(model_header), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(model_header);\n\n    struct gptoss_uuid layout_uuid;\n    status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(layout_uuid);\n\n    if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {\n        GPTOSS_LOG_ERROR(\"unsupported layout UUID \" UUID_FORMAT, UUID_ARGS(layout_uuid));\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n\n    const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);\n    model = malloc(model_size);\n    if (model == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to allocate %zu bytes for model descriptor\", model_size);\n        status = gptoss_status_insufficient_memory;\n        goto cleanup;\n    }\n    memset(model, 0, model_size);\n\n    atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);\n    model->context_length = 
model_header.context_length;\n    model->num_blocks = model_header.num_blocks;\n    model->num_experts = model_header.num_experts;\n    model->num_active_experts = model_header.num_active_experts;\n    model->embedding_dim = model_header.embedding_dim;\n    model->mlp_dim = model_header.mlp_dim;\n    model->swiglu_limit = model_header.swiglu_limit;\n    model->head_dim = model_header.head_dim;\n    model->num_heads = model_header.num_heads;\n    model->num_kv_heads = model_header.num_kv_heads;\n    model->attention_window = model_header.attention_window;\n    model->rope_theta = model_header.rope_theta;\n    model->interpolation_scale = model_header.interpolation_scale;\n    model->yarn_offset = model_header.yarn_offset;\n    model->yarn_scale = model_header.yarn_scale;\n    model->yarn_multiplier = model_header.yarn_multiplier;\n    model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;\n\n    struct gptoss_uuid tokenizer_uuid;\n    status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(tokenizer_uuid);\n\n    if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {\n        GPTOSS_LOG_ERROR(\"unsupported tokenizer UUID \" UUID_FORMAT, UUID_ARGS(tokenizer_uuid));\n        status = gptoss_status_invalid_argument;\n        goto cleanup;\n    }\n\n    struct gptoss_tiktoken_tokenizer_header tokenizer_header;\n    status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    file_offset += sizeof(tokenizer_header);\n\n    tokenizer = malloc(sizeof(struct gptoss_tokenizer));\n    if (tokenizer == NULL) {\n        GPTOSS_LOG_ERROR(\"failed to allocate %zu bytes for tokenizer descriptor\", sizeof(struct gptoss_tokenizer));\n        status = gptoss_status_insufficient_memory;\n        goto cleanup;\n    }\n    memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));\n    
// Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)\n    memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));\n\n    atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);\n    tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;\n    tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;\n    model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;\n    for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {\n        struct gptoss_uuid token_uuid;\n        status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);\n        if (status != gptoss_status_success) {\n            goto cleanup;\n        }\n        file_offset += sizeof(token_uuid);\n\n        const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);\n        if (token != gptoss_special_token_invalid) {\n            tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;\n        }\n    }\n\n    const size_t tokenizer_start_offset = file_offset;\n    const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;\n    const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);\n    const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;\n    void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);\n    if (tokenizer_mapping_ptr == (void*) -1) {\n        GPTOSS_LOG_ERROR(\"failed to mmap(%s) tokenizer at offset %zu size %zu\",\n            path, tokenizer_mapping_start, tokenizer_mapping_size);\n        status = gptoss_status_io_error;\n        goto cleanup;\n    }\n    tokenizer->mapping_ptr = tokenizer_mapping_ptr;\n    tokenizer->mapping_size = tokenizer_mapping_size;\n    tokenizer->regex_ptr = (const char*) 
tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);\n    tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;\n\n    // madvise() advice values are not bit flags, so each hint needs its own call\n    if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM) != 0) {\n        GPTOSS_LOG_WARNING(\"madvise(%s, size=%zu, MADV_RANDOM) failed with error %d\", path, tokenizer_mapping_size, errno);\n    }\n    if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_WILLNEED) != 0) {\n        GPTOSS_LOG_WARNING(\"madvise(%s, size=%zu, MADV_WILLNEED) failed with error %d\", path, tokenizer_mapping_size, errno);\n    }\n\n    prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);\n\n    struct stat model_stat = {0};\n    int stat_result = fstat(fd, &model_stat);\n    if (stat_result != 0) {\n        GPTOSS_LOG_ERROR(\"fstat(%s) failed with error %d\", path, errno);\n        status = gptoss_status_io_error;\n        goto cleanup;\n    }\n\n    const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);\n    const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;\n    void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);\n    if (model_mapping_ptr == (void*) -1) {\n        GPTOSS_LOG_ERROR(\"failed to mmap(%s) model weights at offset %zu size %zu\",\n            path, model_mapping_start, model_mapping_size);\n        status = gptoss_status_io_error;\n        goto cleanup;\n    }\n    model->mapping_ptr = model_mapping_ptr;\n    model->mapping_size = model_mapping_size;\n\n    if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL) != 0) {\n        GPTOSS_LOG_WARNING(\"madvise(%s, size=%zu, MADV_SEQUENTIAL) failed with error %d\", path, model_mapping_size, errno);\n    }\n    if (madvise(model_mapping_ptr, model_mapping_size, MADV_WILLNEED) != 0) {\n        GPTOSS_LOG_WARNING(\"madvise(%s, size=%zu, MADV_WILLNEED) failed with error %d\", path, model_mapping_size, errno);\n    }\n\n    prefetch_fd(fd, model_mapping_start, model_mapping_size, path);\n\n    if (mlock(model_mapping_ptr, model_mapping_size) != 0) {\n        GPTOSS_LOG_WARNING(\"mlock(%s, size=%zu) failed with error %d\", path, model_mapping_size, errno);\n    } else {\n        model->lock_memory = true;\n    }\n\n    // Initialize Metal\n    status = 
gptoss_metal_device_create_system_default(&model->device);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    model->max_threadgroups = model->device.num_cores * 3;\n    status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n\n    // Metal kernels\n    status = gptoss_metal_library_create_default(&model->device, &model->library);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_bf16_f32_embeddings\", &model->bf16_f32_embeddings_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_rmsnorm\", &model->f32_bf16w_rmsnorm_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_matmul\", &model->f32_bf16w_matmul_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_matmul_qkv\", &model->f32_bf16w_matmul_qkv_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_dense_matmul_qkv\", &model->f32_bf16w_dense_matmul_qkv_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_dense_matmul_attn_output\", &model->f32_bf16w_dense_matmul_attn_output_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_dense_matmul_mlp_gate\", &model->f32_bf16w_dense_matmul_mlp_gate_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    
status = gptoss_metal_function_create(&model->library, \"gptoss_f32_bf16w_unembedding\", &model->f32_bf16w_unembedding_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_rope\", &model->f32_rope_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_expert_routing_metadata\", &model->f32_expert_routing_metadata_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_scatter_e4\", &model->f32_scatter_e4_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_mf4w_moe_dense_matmul_swiglu\", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_mf4w_moe_dense_matmul\", &model->f32_mf4w_moe_dense_matmul_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_gather_and_accumulate_e4\", &model->f32_gather_and_accumulate_e4_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_mf4w_moe_matmul_swiglu\", &model->f32_mf4w_moe_matmul_swiglu_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_mf4w_moe_matmul\", &model->f32_mf4w_moe_matmul_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_accumulate_e4\", &model->f32_accumulate_e4_fn);\n    if (status != 
gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_topk_softmax_e32_k4\", &model->f32_topk_softmax_e32_k4_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_topk_softmax_e128_k4\", &model->f32_topk_softmax_e128_k4_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_softmax\", &model->f32_softmax_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_sample\", &model->f32_sample_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n    status = gptoss_metal_function_create(&model->library, \"gptoss_f32_sdpa_q8_d64\", &model->f32_sdpa_q8_d64_fn);\n    if (status != gptoss_status_success) {\n        goto cleanup;\n    }\n\n    // Kernel launch parameters\n    model->embeddings_threadgroup_size = 512;\n    model->attn_qkv_threadgroup_size = 1024;\n    model->attn_out_threadgroup_size = 768;\n    model->mlp_gate_threadgroup_size = 256;\n    model->mlp_swiglu_threadgroup_size = 192;\n    model->mlp_out_threadgroup_size = 192;\n    model->mlp_acc_threadgroup_size = 768;\n    model->unembedding_threadgroup_size = 416;\n\n    // Weight buffers\n    const char* current_ptr = (const char*) model->mapping_ptr;\n\n    const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->attn_rmsnorm_gain_offset = embedding_weight_size;\n    const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;\n    const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * 
model->num_kv_heads);\n    const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;\n    const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);\n    model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;\n    const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);\n    model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;\n    const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);\n    model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;\n    const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;\n    model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;\n    const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;\n    const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);\n    const size_t per_block_shared_weights_size =\n        rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +\n        rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;\n    model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;\n    model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;\n    const size_t unembedding_weight_size = 
math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n\n    model->per_block_shared_weights_size = per_block_shared_weights_size;\n    const size_t shared_weights_size =\n        round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);\n\n    status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);\n    if (status != gptoss_status_success) {\n        GPTOSS_LOG_ERROR(\"failed to map expert-shared weight of size %zu onto a Metal buffer\", shared_weights_size);\n        goto cleanup;\n    }\n    current_ptr += shared_weights_size;\n    model->weights_size += shared_weights_size;\n\n    const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);\n    model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;\n    const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);\n    model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;\n    const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);\n    model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;\n    const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);\n    model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;\n    const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);\n    model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;\n    const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);\n    model->per_expert_block_weight_size =\n        mlp_swiglu_weight_block_size + 
mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;\n    const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);\n    for (uint32_t n = 0; n < model->num_blocks; n++) {\n        status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);\n        if (status != gptoss_status_success) {\n            GPTOSS_LOG_ERROR(\"failed to map block #%\" PRIu32 \" MoE weight of size %zu onto a Metal buffer\",\n                n, moe_block_weight_size);\n            goto cleanup;\n        }\n        current_ptr += moe_block_weight_size;\n        model->weights_size += moe_block_weight_size;\n    }\n\n    // Commit tokenizer\n    model->tokenizer = tokenizer;\n    tokenizer = NULL;\n\n    // Commit model\n    *model_out = model;\n    model = NULL;\n\ncleanup:\n    if (fd != -1) {\n        close(fd);\n        fd = -1;\n    }\n    gptoss_model_release(model);  // does nothing if model is NULL\n    gptoss_tokenizer_release(tokenizer);  // does nothing if tokenizer is NULL\n    return status;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(\n    gptoss_model_t model,\n    gptoss_tokenizer_t* tokenizer_out)\n{\n    gptoss_tokenizer_t tokenizer = model->tokenizer;\n    atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);\n    *tokenizer_out = tokenizer;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(\n    gptoss_model_t model,\n    size_t* max_context_length_out)\n{\n    *max_context_length_out = model->context_length;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_model_retain(\n    gptoss_model_t model)\n{\n    atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI 
gptoss_model_release(\n    gptoss_model_t model)\n{\n    if (model != NULL) {\n        if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {\n            gptoss_tokenizer_release(model->tokenizer);\n\n            // Weight buffers\n            gptoss_metal_buffer_release(&model->shared_weight_buffer);\n            for (uint32_t n = 0; n < model->num_blocks; n++) {\n                gptoss_metal_buffer_release(&model->block_weight_buffers[n]);\n            }\n\n            // Metal kernels\n            gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);\n            gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);\n            gptoss_metal_function_release(&model->f32_rope_fn);\n            gptoss_metal_function_release(&model->f32_expert_routing_metadata_fn);\n            gptoss_metal_function_release(&model->f32_scatter_e4_fn);\n            gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn);\n            gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn);\n            gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn);\n            gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);\n            gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);\n            gptoss_metal_function_release(&model->f32_accumulate_e4_fn);\n            gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);\n            
gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);\n            gptoss_metal_function_release(&model->f32_softmax_fn);\n            gptoss_metal_function_release(&model->f32_sample_fn);\n            gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);\n            gptoss_metal_library_release(&model->library);\n\n            gptoss_metal_command_queue_release(&model->command_queue);\n            gptoss_metal_device_release(&model->device);\n\n            // Model weight mapping\n            if (model->mapping_ptr != NULL && model->mapping_size != 0) {\n                if (model->lock_memory) {\n                    if (munlock(model->mapping_ptr, model->mapping_size) != 0) {\n                        GPTOSS_LOG_WARNING(\"munlock for model weight mapping failed with error %d\", errno);\n                    }\n                }\n\n                if (munmap(model->mapping_ptr, model->mapping_size) != 0) {\n                    GPTOSS_LOG_WARNING(\"munmap for model weight mapping failed with error %d\", errno);\n                }\n            }\n\n            const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);\n            memset(model, 0, model_size);\n            free(model);\n        }\n    }\n    return gptoss_status_success;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/moematmul.metal",
    "content": "#include <internal/kernel-args.h>\n#include <metal_common>\n#include <metal_compute>\n#include <metal_math>\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n#define ceil_div(a, b) (((a) + (b) - 1) / (b))\n\n// Each simdgroup reduces all channels of the input and computes a single channel of the output\n// + Efficient synchronization\n// + Sequential memory access within a warp\n// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels\n// + Reuse input vector from threadgroup memory\n// + Avoid synchronization across warps when doing reduction\n\nkernel void gptoss_f32_mf4w_moe_matmul_swiglu(\n    constant gptoss_moe_matmul_swiglu_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device gptoss_expert_prediction* expert [[ buffer(2) ]],\n    const device uint4* weight_blocks [[ buffer(3) ]],\n    const device uchar* weight_scales [[ buffer(4) ]],\n    const device bfloat* bias [[ buffer(5) ]],\n    device float* output [[ buffer(6) ]],\n    const device gptoss_control* control [[ buffer(7) ]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_index_in_threadgroup]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    threadgroup float threadgroup_buffer[32];\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_column_vecs = args.num_column_vecs;\n    const uint row = gid.x * num_simdgroups + simdgroup_idx;\n    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;\n\n    input += 8 * (gid.y * num_column_vecs + simdgroup_tid);\n    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);\n 
   weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);\n    bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);\n    output += gid.y * args.num_rows + gid.x * (num_simdgroups / 2) + gid.z * args.output_expert_stride;\n\n    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;\n\n    float4 sum4 = 0.0f;\n    do {\n        const uint4 wblock = *weight_blocks;\n        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);\n        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;\n        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;\n        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;\n        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;\n        wblock02468ACEGIKMOQSU += 0x70707070u;\n        wblock13579BDFHJLNPRTV += 0x70707070u;\n        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;\n        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;\n        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;\n        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;\n        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;\n        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;\n        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));\n        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));\n        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));\n        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));\n        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));\n        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));\n        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));\n        const float4 wJNRV = 
static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));\n\n        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };\n        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };\n        const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };\n        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };\n        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };\n        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };\n        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };\n        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };\n\n        const float4 i0123 = input[0];\n        const float4 i4567 = input[1];\n        const float4 i89AB = input[2];\n        const float4 iCDEF = input[3];\n        const float4 iGHIJ = input[4];\n        const float4 iKLMN = input[5];\n        const float4 iOPQR = input[6];\n        const float4 iSTUV = input[7];\n\n        float4 psum0 = i0123 * w0123;\n        float4 psum1 = i4567 * w4567;\n        psum0 = metal::fma(i89AB, w89AB, psum0);\n        psum1 = metal::fma(iCDEF, wCDEF, psum1);\n        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);\n        psum1 = metal::fma(iKLMN, wKLMN, psum1);\n        psum0 = metal::fma(iOPQR, wOPQR, psum0);\n        psum1 = metal::fma(iSTUV, wSTUV, psum1);\n        sum4 = metal::fma(psum0, wscale, sum4);\n        sum4 = metal::fma(psum1, wscale, sum4);\n\n        weight_blocks += simdgroup_size;\n        weight_scales += simdgroup_size;\n        input += 8 * simdgroup_size;\n    } while (--num_iter != 0);\n    const float2 sum2 = sum4.xy + sum4.zw;\n    float sum = sum2.x + sum2.y;\n    sum = metal::simd_sum(sum);\n    if (metal::simd_is_first()) {\n        sum += static_cast<float>(*bias);\n        threadgroup_buffer[simdgroup_idx] = sum;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    if (tid * 2 < num_simdgroups) {\n  
      const float2 x = reinterpret_cast<const threadgroup float2*>(threadgroup_buffer)[tid];\n        const float swish_x = metal::min(x.x, args.swiglu_max);\n        const float linear_x = metal::clamp(x.y, args.swiglu_min, args.swiglu_max);\n        const float alpha = 1.702f;\n        const float swish_y = swish_x / (1.0f + metal::precise::exp(-alpha * swish_x));\n        const float swiglu_y = metal::fma(swish_y, linear_x, swish_y);\n        output[tid] = swiglu_y;\n    }\n}\n\nkernel void gptoss_f32_mf4w_moe_matmul(\n    constant gptoss_moe_matmul_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device gptoss_expert_prediction* expert [[ buffer(2) ]],\n    const device uint4* weight_blocks [[ buffer(3) ]],\n    const device uchar* weight_scales [[ buffer(4) ]],\n    const device bfloat* bias [[ buffer(5) ]],\n    device float* output [[ buffer(6) ]],\n    const device gptoss_control* control [[ buffer(7) ]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_index_in_threadgroup]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_column_vecs = args.num_column_vecs;\n    const uint row = gid.x * num_simdgroups + simdgroup_idx;\n    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;\n\n    input += 8 * (gid.y * num_column_vecs + simdgroup_tid + gid.z * args.input_expert_stride);\n    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);\n    weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);\n    bias = (const device bfloat*) ((uintptr_t) (bias + 
row) + expert_id * args.weight_expert_stride);\n    output += gid.y * args.num_rows + row + gid.z * args.output_expert_stride;\n\n    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;\n\n    float4 sum4 = 0.0f;\n    do {\n        const uint4 wblock = *weight_blocks;\n        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);\n        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;\n        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;\n        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;\n        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;\n        wblock02468ACEGIKMOQSU += 0x70707070u;\n        wblock13579BDFHJLNPRTV += 0x70707070u;\n        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;\n        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;\n        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;\n        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;\n        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;\n        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;\n        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));\n        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));\n        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));\n        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));\n        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));\n        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));\n        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));\n        const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));\n\n        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };\n        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };\n        const float4 w89AB = (float4) { 
w048C.z, w159D.z, w26AE.z, w37BF.z };\n        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };\n        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };\n        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };\n        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };\n        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };\n\n        const float4 i0123 = input[0];\n        const float4 i4567 = input[1];\n        const float4 i89AB = input[2];\n        const float4 iCDEF = input[3];\n        const float4 iGHIJ = input[4];\n        const float4 iKLMN = input[5];\n        const float4 iOPQR = input[6];\n        const float4 iSTUV = input[7];\n\n        float4 psum0 = i0123 * w0123;\n        float4 psum1 = i4567 * w4567;\n        psum0 = metal::fma(i89AB, w89AB, psum0);\n        psum1 = metal::fma(iCDEF, wCDEF, psum1);\n        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);\n        psum1 = metal::fma(iKLMN, wKLMN, psum1);\n        psum0 = metal::fma(iOPQR, wOPQR, psum0);\n        psum1 = metal::fma(iSTUV, wSTUV, psum1);\n        sum4 = metal::fma(psum0, wscale, sum4);\n        sum4 = metal::fma(psum1, wscale, sum4);\n\n        weight_blocks += simdgroup_size;\n        weight_scales += simdgroup_size;\n        input += 8 * simdgroup_size;\n    } while (--num_iter != 0);\n    const float2 sum2 = sum4.xy + sum4.zw;\n    float sum = sum2.x + sum2.y;\n    sum = metal::simd_sum(sum);\n    if (metal::simd_is_first()) {\n        sum += static_cast<float>(*bias);\n        *output = sum;\n    }\n}\n\nkernel void gptoss_f32_mf4w_moe_dense_matmul_swiglu(\n    constant gptoss_moe_dense_matmul_swiglu_args& params [[ buffer(0) ]],\n    const device uint* __restrict__ expert_offsets [[ buffer(1) ]],\n    const device float* lhs [[ buffer(2) ]],\n    const device uint* weight_blocks [[ buffer(3) ]],\n    const device uchar* weight_scales [[ buffer(4) ]],\n    const device bfloat* 
__restrict__ bias [[ buffer(5) ]],\n    device float* out [[ buffer(6) ]],\n    uint sg_id [[simdgroup_index_in_threadgroup]],\n    uint3 threads_per_tg [[threads_per_threadgroup]],\n    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 tg_id [[threadgroup_position_in_grid]],\n    uint3 local_tid [[thread_position_in_threadgroup]]) \n{\n    constexpr uint Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;\n    constexpr uint Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;\n    constexpr uint Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;\n    constexpr uint Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;\n    constexpr uint Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;\n\n    // Assumptions about shapes.\n    assert(Bm % 8 == 0);\n    assert(Bn % 8 == 0);\n    assert(Bk % 8 == 0);\n    assert(Sg_Bm % 8 == 0);\n    assert(Sg_Bn % 8 == 0);\n    assert(Bm % Sg_Bm == 0);\n    assert(Bn % Sg_Bn == 0);\n\n    const uint K = params.k;\n    const uint N = params.n;\n    const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];\n    assert((K % 32) == 0);\n    assert((K % 8) == 0);\n    assert(N % Bn == 0);\n    assert(K % Bk == 0);\n    // Get row and col tg.\n    const uint row_tg = tg_id.y;\n    const uint col_tg = tg_id.x;\n    // Get row and col offsets of this threadgroup's output tile.\n    const uint row_tg_offset = row_tg * Bm;\n    const uint col_tg_offset = col_tg * Bn;\n    if (row_tg_offset >= M || col_tg_offset >= N) {\n        return;\n    }\n    // Move lhs and output according to the passed offset.\n    const uint expert_offset = expert_offsets[tg_id.z];\n    lhs += expert_offset * K;\n    const uint N_output = N / 2;\n    out += expert_offset * N_output;\n\n    const uint S = params.weight_blocks_expert_stride_bytes;\n    const uint S_scales = params.weight_scales_expert_stride_bytes;\n    const uint S_bias = params.bias_expert_stride_bytes;\n\n    const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);\n    const device char* sc0 = 
reinterpret_cast<const device char*>(weight_scales);\n    const device char* bi0 = reinterpret_cast<const device char*>(bias);\n\n    weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);\n    weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);\n    bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);\n\n    const uint sg_col_count = Bn / Sg_Bn;\n    const uint row_sg = sg_id / sg_col_count;\n    const uint col_sg = sg_id % sg_col_count;\n\n    const uint row_sg_offset = row_sg * Sg_Bm;\n    const uint col_sg_offset = col_sg * Sg_Bn;\n    // Declare threadgroup blocks.\n    threadgroup float lhs_block[Bm * Bk];\n    // rhs_block will hold the scaled fp32 weights.\n    threadgroup float rhs_block[Bn * Bk];\n\n    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);\n    // Create an array of simdgroup_float8x8 to hold temp results.\n    metal::simdgroup_float8x8 OutTiles[temp_result_size];\n    for (uint i = 0; i < temp_result_size; i++) {\n        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);\n    }\n    // Linear thread id within TG (we launch 1-D TGs)\n    const uint lin_tid = local_tid.x;\n    const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;\n\n    // Iterate over all Bk blocks.\n    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {\n        constexpr uint lhs_row_stride = Bk;\n        constexpr uint lhs_vec_cols = Bk / 4;\n        constexpr uint lhs_vec_total = Bm * lhs_vec_cols;\n\n        const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);\n\n        // #pragma clang loop unroll(full)\n        for (uint t = 0; t < LHS_ITERS; ++t) {\n            const uint i = t * thread_count_per_tg + lin_tid;\n            if (i < lhs_vec_total) {\n                const uint r = i / lhs_vec_cols;\n                const uint c4 = i % lhs_vec_cols;\n\n                const uint gr = row_tg_offset + r;\n                
const uint gc4 = (k_offset / 4) + c4;\n\n                threadgroup float4* dst4 =\n                    reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));\n                if (gr < M) {\n                    const device float4* src4 =\n                        reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));\n\n                    *dst4 = *src4;\n                } else {\n                    *dst4 = float4(0.0);\n                }\n            }\n        }\n\n        // Load weights with vector loads.\n        constexpr uint rhs_row_stride = Bk;\n        constexpr uint weights_per_elem = 8;\n        constexpr uint rhs_loads_per_col = Bk / weights_per_elem;\n        constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;\n        const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);\n        // #pragma clang loop unroll(full)\n        for (uint t = 0; t < RHS_ITERS; ++t) {\n            const uint i = t * thread_count_per_tg + lin_tid;\n            if (i < rhs_loads_total) {\n                const uint r = i / rhs_loads_per_col;\n                const uint c = i % rhs_loads_per_col;\n\n                const uint gr = col_tg_offset + r;\n                const uint gc = (k_offset / weights_per_elem) + c;\n                const uint gc_scale = (k_offset / 32) + (c >> 2);\n\n                const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];\n                const float scale =\n                    as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);\n                uint wblock0246 = (wblock + wblock);\n                uint wblock1357 = (wblock >> 3);\n                wblock0246 &= 0x1E1E1E1Eu;\n                wblock1357 &= 0x1E1E1E1Eu;\n\n                wblock0246 += 0x70707070u;\n                wblock1357 += 0x70707070u;\n                wblock0246 &= 0x8E8E8E8Eu;\n                wblock1357 &= 0x8E8E8E8Eu;\n\n                uint wblock26 = 
(wblock0246) & 0xFF00FF00u;\n                uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;\n                uint wblock37 = (wblock1357) & 0xFF00FF00u;\n                uint wblock15 = ((wblock1357 << 8)) & 0xFF00FF00u;\n\n                half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));\n                half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));\n\n                // Convert to float scalars and apply scale\n                const float w0 = float(wblock0426.x) * scale;\n                const float w1 = float(wblock1537.x) * scale;\n                const float w2 = float(wblock0426.z) * scale;\n                const float w3 = float(wblock1537.z) * scale;\n                const float w4 = float(wblock0426.y) * scale;\n                const float w5 = float(wblock1537.y) * scale;\n                const float w6 = float(wblock0426.w) * scale;\n                const float w7 = float(wblock1537.w) * scale;\n                const uint rhs_offset = r * rhs_row_stride + c * 8;\n                rhs_block[rhs_offset] = w0;\n                rhs_block[rhs_offset + 1] = w1;\n                rhs_block[rhs_offset + 2] = w2;\n                rhs_block[rhs_offset + 3] = w3;\n                rhs_block[rhs_offset + 4] = w4;\n                rhs_block[rhs_offset + 5] = w5;\n                rhs_block[rhs_offset + 6] = w6;\n                rhs_block[rhs_offset + 7] = w7;\n            }\n        }\n        threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n#pragma clang loop unroll(full)\n        for (uint k = 0; k < Bk; k += 8) {\n#pragma clang loop unroll(full)\n            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n                const uint row_index_in_out_tile = m_subtile_ / 8;\n                metal::simdgroup_float8x8 lhs_frag;\n\n                simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));\n#pragma clang loop unroll(full)\n                for (uint n_subtile_ = 0; n_subtile_ < 
Sg_Bn; n_subtile_ += 8) {\n                    const uint col_index_in_out_tile = n_subtile_ / 8;\n                    const uint current_index_out_tile =\n                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n                    metal::simdgroup_float8x8 rhs_frag;\n                    simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);\n\n                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,\n                        OutTiles[current_index_out_tile]);\n                }\n            }\n        }\n        threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    }\n\n    // Epilogue.\n    threadgroup float scratch[Bm * Bn];\n#pragma clang loop unroll(full)\n    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n        const uint col_index_in_out_tile = n_subtile_ / 8;\n        const uint local_col_offset = col_sg_offset + n_subtile_;\n#pragma clang loop unroll(full)\n        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n            const uint row_index_in_out_tile = m_subtile_ / 8;\n            const uint local_row_offset = row_sg_offset + m_subtile_;\n            const uint current_index_out_tile =\n                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,\n                ulong2(local_col_offset, local_row_offset));\n        }\n    }\n    threadgroup float bias_tile[Bn];\n    // TODO(ibahmed): vectorize these loads and maybe unroll the loop.\n    for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {\n        const uint c_global = col_tg_offset + c_local;\n        bias_tile[c_local] = (c_global < N) ? 
static_cast<float>(bias[c_global]) : 0.0f;\n    }\n\n    threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    const float alpha = 1.702f;\n    // TODO(ibahmed): vectorize these stores and maybe unroll the loop.\n    for (uint idx = local_tid.x; idx < Bm * Bn / 2; idx += thread_count_per_tg) {\n        const uint idx_swish = idx * 2;\n        const uint r = idx_swish / Bn;\n        const uint c_swish = idx_swish % Bn;\n\n        const uint out_row = row_tg_offset + r;\n        const uint out_col = (col_tg_offset / 2) + (c_swish / 2);\n\n        if (out_row < M && out_col < N_output) {\n            float acc_swish = scratch[idx_swish] + bias_tile[c_swish];\n            float acc_linear = scratch[idx_swish + 1] + bias_tile[c_swish + 1];\n            const float swish = metal::min(acc_swish, params.swiglu_max);\n            const float linear = metal::clamp(acc_linear, params.swiglu_min, params.swiglu_max);\n            const float swish_y = swish / (1.0f + metal::precise::exp(-alpha * swish));\n            const float swiglu_y = metal::fma(swish_y, linear, swish_y);\n            out[out_row * N_output + out_col] = swiglu_y;\n        }\n    }\n}\n\nkernel void gptoss_f32_mf4w_moe_dense_matmul(\n    constant gptoss_moe_dense_matmul_args& params [[ buffer(0) ]],\n    const device uint* __restrict__ expert_offsets [[ buffer(1) ]],\n    const device float* lhs [[ buffer(2) ]],\n    const device uint* weight_blocks [[ buffer(3) ]],\n    const device uchar* weight_scales [[ buffer(4) ]],\n    const device bfloat* __restrict__ bias [[ buffer(5) ]],\n    device float* out [[ buffer(6) ]],\n    uint sg_id [[simdgroup_index_in_threadgroup]],\n    uint3 threads_per_tg [[threads_per_threadgroup]],\n    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 tg_id [[threadgroup_position_in_grid]],\n    uint3 local_tid [[thread_position_in_threadgroup]]) \n{\n    const uint Bm = MOE_DENSE_MATMUL_Bm;\n    const uint 
Bn = MOE_DENSE_MATMUL_Bn;\n    const uint Bk = MOE_DENSE_MATMUL_Bk;\n    const uint Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;\n    const uint Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;\n    assert(Bm % 8 == 0);\n    assert(Bn % 8 == 0);\n    assert(Bk % 8 == 0);\n    assert(Sg_Bm % 8 == 0);\n    assert(Sg_Bn % 8 == 0);\n    assert(Bm % Sg_Bm == 0);\n    assert(Bn % Sg_Bn == 0);\n\n    const uint K = params.k;\n    const uint N = params.n;\n    const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];\n    assert((K % 32) == 0);\n    assert((K % 8) == 0);\n    assert(N % Bn == 0);\n    assert(K % Bk == 0);\n    // Output tile coordinates of this threadgroup.\n    const uint row_tg = tg_id.y;\n    const uint col_tg = tg_id.x;\n    // Row and column offsets of this threadgroup's output tile.\n    const uint row_tg_offset = row_tg * Bm;\n    const uint col_tg_offset = col_tg * Bn;\n    if (row_tg_offset >= M || col_tg_offset >= N) {\n        return;\n    }\n    // Advance lhs and out to the first token row of this expert.\n    const uint expert_offset = expert_offsets[tg_id.z];\n    lhs += expert_offset * K;\n    out += expert_offset * N;\n\n    const uint S = params.weight_blocks_expert_stride_bytes;\n    const uint S_scales = params.weight_scales_expert_stride_bytes;\n    const uint S_bias = params.bias_expert_stride_bytes;\n\n    const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);\n    const device char* sc0 = reinterpret_cast<const device char*>(weight_scales);\n    const device char* bi0 = reinterpret_cast<const device char*>(bias);\n\n    weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);\n    weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);\n    bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);\n\n    const uint sg_col_count = Bn / Sg_Bn;\n    const uint row_sg = sg_id / sg_col_count;\n    const uint col_sg = sg_id % sg_col_count;\n\n    const uint row_sg_offset = row_sg * Sg_Bm;\n    const uint col_sg_offset = col_sg * Sg_Bn;\n    
// Declare threadgroup blocks.\n    threadgroup float lhs_block[Bm * Bk];\n    // rhs_block will hold the scaled fp32 weights.\n    threadgroup float rhs_block[Bn * Bk];\n\n    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);\n    // Create an array of simdgroup_float8x8 to hold temp results.\n    metal::simdgroup_float8x8 OutTiles[temp_result_size];\n    for (uint i = 0; i < temp_result_size; i++) {\n        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);\n    }\n    // Linear thread id within TG (we launch 1-D TGs)\n    const uint lin_tid = local_tid.x;\n\n    const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;\n    // Iterate over all Bk blocks.\n    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {\n        constexpr uint lhs_row_stride = Bk;\n        constexpr uint lhs_vec_cols = Bk / 4;\n        constexpr uint lhs_vec_total = Bm * lhs_vec_cols;\n\n        const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);\n\n        for (uint t = 0; t < LHS_ITERS; ++t) {\n            const uint i = t * thread_count_per_tg + lin_tid;\n            if (i < lhs_vec_total) {\n                const uint r = i / lhs_vec_cols;\n                const uint c4 = i % lhs_vec_cols;\n\n                const uint gr = row_tg_offset + r;\n                const uint gc4 = (k_offset / 4) + c4;\n\n                threadgroup float4* dst4 =\n                    reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));\n                if (gr < M) {\n                    const device float4* src4 =\n                        reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));\n\n                    *dst4 = *src4;\n                } else {\n                    *dst4 = float4(0.0);\n                }\n            }\n        }\n\n        // Load weights with vector loads.\n        constexpr uint rhs_row_stride = Bk;\n        constexpr uint weights_per_elem = 8;\n        
constexpr uint rhs_loads_per_col = Bk / weights_per_elem;\n        constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;\n        const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);\n        // #pragma clang loop unroll(full)\n        for (uint t = 0; t < RHS_ITERS; ++t) {\n            const uint i = t * thread_count_per_tg + lin_tid;\n            if (i < rhs_loads_total) {\n                const uint r = i / rhs_loads_per_col;\n                const uint c = i % rhs_loads_per_col;\n\n                const uint gr = col_tg_offset + r;\n                const uint gc = (k_offset / weights_per_elem) + c;\n                const uint gc_scale = (k_offset / 32) + (c >> 2);\n\n                const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];\n                // One E8M0 scale per 32 weights: moving the byte into the float exponent field yields 2^(byte - 127) for nonzero bytes.\n                const float scale =\n                    as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);\n\n                // Expand the eight packed FP4 (E2M1) nibbles into fp16 bit patterns: the sign goes to bit 15, the\n                // 2-bit exponent to the low exponent bits, and the mantissa bit to the top mantissa bit, so each\n                // half holds the FP4 value times 2^-14.\n                uint wblock0246 = (wblock + wblock);\n                uint wblock1357 = (wblock >> 3);\n                wblock0246 &= 0x1E1E1E1Eu;\n                wblock1357 &= 0x1E1E1E1Eu;\n\n                wblock0246 += 0x70707070u;\n                wblock1357 += 0x70707070u;\n                wblock0246 &= 0x8E8E8E8Eu;\n                wblock1357 &= 0x8E8E8E8Eu;\n\n                uint wblock26 = (wblock0246) & 0xFF00FF00u;\n                uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;\n                uint wblock37 = (wblock1357) & 0xFF00FF00u;\n                uint wblock15 = ((wblock1357 << 8)) & 0xFF00FF00u;\n\n                half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));\n                half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));\n\n                // Convert to float scalars and apply scale\n                const float w0 = float(wblock0426.x) * scale;\n                const float w1 = float(wblock1537.x) * scale;\n                const float w2 = float(wblock0426.z) * scale;\n                const float w3 = float(wblock1537.z) * scale;\n                const 
float w4 = float(wblock0426.y) * scale;\n                const float w5 = float(wblock1537.y) * scale;\n                const float w6 = float(wblock0426.w) * scale;\n                const float w7 = float(wblock1537.w) * scale;\n                const uint rhs_offset = r * rhs_row_stride + c * 8;\n                rhs_block[rhs_offset] = w0;\n                rhs_block[rhs_offset + 1] = w1;\n                rhs_block[rhs_offset + 2] = w2;\n                rhs_block[rhs_offset + 3] = w3;\n                rhs_block[rhs_offset + 4] = w4;\n                rhs_block[rhs_offset + 5] = w5;\n                rhs_block[rhs_offset + 6] = w6;\n                rhs_block[rhs_offset + 7] = w7;\n            }\n        }\n        threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n#pragma clang loop unroll(full)\n        for (uint k = 0; k < Bk; k += 8) {\n#pragma clang loop unroll(full)\n            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n                const uint row_index_in_out_tile = m_subtile_ / 8;\n                metal::simdgroup_float8x8 lhs_frag;\n\n                simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));\n#pragma clang loop unroll(full)\n                for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n                    const uint col_index_in_out_tile = n_subtile_ / 8;\n                    const uint current_index_out_tile =\n                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n                    metal::simdgroup_float8x8 rhs_frag;\n                    simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);\n                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,\n                        OutTiles[current_index_out_tile]);\n                }\n            }\n        }\n        threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    }\n\n    // Epilogue.\n    threadgroup 
float scratch[Bm * Bn];\n#pragma clang loop unroll(full)\n    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {\n        const uint col_index_in_out_tile = n_subtile_ / 8;\n        const uint local_col_offset = col_sg_offset + n_subtile_;\n#pragma clang loop unroll(full)\n        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {\n            const uint row_index_in_out_tile = m_subtile_ / 8;\n            const uint local_row_offset = row_sg_offset + m_subtile_;\n            const uint current_index_out_tile =\n                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;\n            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,\n                ulong2(local_col_offset, local_row_offset));\n        }\n    }\n    threadgroup float bias_tile[Bn];\n    for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {\n        const uint c_global = col_tg_offset + c_local;\n        bias_tile[c_local] = (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;\n    }\n\n    threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {\n        const uint r = idx / Bn;\n        const uint c = idx % Bn;\n\n        const uint out_row = row_tg_offset + r;\n        const uint out_col = col_tg_offset + c;\n\n        if (out_row < M && out_col < N) {\n            float acc = scratch[idx] + bias_tile[c];\n            out[out_row * N + out_col] = acc;\n        }\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/random.metal",
    "content": "#include <metal_integer>\n#include <metal_math>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\ninline static uint rng_squares32(ulong offset, ulong seed) {\n    const ulong y = offset * seed;\n    const ulong z = y + seed;\n\n    /* Round 1 */\n    ulong x = y * y + y;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 2 */\n    x = x * x + z;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 3 */\n    x = x * x + y;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 4 */\n    x = x * x + z;\n    return as_type<uint2>(x).y;\n}\n\nkernel void gptoss_u32_fill_random(\n    constant gptoss_u32_fill_random_args& args [[ buffer(0) ]],\n    device uint* output [[ buffer(1) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;\n    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;\n    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);\n    const ulong thread_start = threadgroup_start + tid;\n    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);\n\n    output += thread_start;\n    ulong offset = args.offset + thread_start;\n    for (; num_iter != 0; num_iter--) {\n        *output = rng_squares32(offset, args.seed);\n        output += threadgroup_size;\n        offset += threadgroup_size;\n    }\n}\n\nkernel void gptoss_f32_fill_random(\n    constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],\n    device float* output [[ buffer(1) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;\n    const ulong 
threadgroup_start = gid * num_vecs_per_threadgroup;\n    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);\n    const ulong thread_start = threadgroup_start + tid;\n    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);\n\n    output += thread_start;\n    ulong offset = args.offset + thread_start;\n    for (; num_iter != 0; num_iter--) {\n        const uint word = rng_squares32(offset, args.seed);\n        *output = metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias);\n        output += threadgroup_size;\n        offset += threadgroup_size;\n    }\n}\n\nkernel void gptoss_bf16_fill_random(\n    constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],\n    device bfloat* output [[ buffer(1) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;\n    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;\n    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);\n    const ulong thread_start = threadgroup_start + tid;\n    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);\n\n    output += thread_start;\n    ulong offset = args.offset + thread_start;\n    for (; num_iter != 0; num_iter--) {\n        const uint word = rng_squares32(offset, args.seed);\n        *output = static_cast<bfloat>(metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias));\n        output += threadgroup_size;\n        offset += threadgroup_size;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/rmsnorm.metal",
    "content": "#include <metal_compute>\n#include <metal_math>\n#include <metal_simdgroup>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\n[[max_total_threads_per_threadgroup(1024)]]\nkernel void gptoss_f32_bf16w_rmsnorm(\n    constant gptoss_rmsnorm_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    const device bfloat4* weights [[ buffer(2) ]],\n    device float4* output [[ buffer(3) ]],\n    const device gptoss_control* control [[ buffer(4) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[ threads_per_threadgroup ]])\n{\n    const uint simdgroup_size = 32;\n    threadgroup float threadgroup_buffer[32];\n    if (control->abort != 0) {\n        return;\n    }\n\n    input += gid * args.num_vecs;\n    output += gid * args.num_vecs;\n\n    float4 sumsq4 = 0.0f;\n    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {\n        const float4 val = input[i];\n        sumsq4 = metal::fma(val, val, sumsq4);\n    }\n\n    // Tree-reduce sumsq within thread, then all-reduce within threadgroup.\n    const float2 sumsq2 = sumsq4.xy + sumsq4.zw;\n    float sumsq = sumsq2.x + sumsq2.y;\n    // Warning: this all-reduce works only for simdgroup of 32 threads and threadgroup of 32*32=1024 threads.\n    sumsq = metal::simd_sum(sumsq);\n    if (metal::simd_is_first()) {\n        const uint simdgroup_idx = tid / simdgroup_size;\n        threadgroup_buffer[simdgroup_idx] = sumsq;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    const uint simdgroup_tid = tid % simdgroup_size;\n    sumsq = threadgroup_buffer[simdgroup_tid];\n    sumsq = metal::simd_sum(sumsq);\n\n    const float avgsq = sumsq / args.num_channels;\n    const float scale = metal::precise::rsqrt(avgsq + args.epsilon);\n    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {\n        const float4 
val = input[i] * scale;\n        const float4 weight_val = static_cast<float4>(weights[i]);\n        output[i] = val * weight_val;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/rope.metal",
    "content": "#include <metal_common>\n#include <metal_math>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\n// Each thread handles 2 head elements.\n// Each simdgroup handles one head (64 head elements).\n\nkernel void gptoss_f32_rope(\n    constant gptoss_rope_args& args [[ buffer(0) ]],\n    device float2* activations [[ buffer(1) ]],\n    device float2* kv [[ buffer(2) ]],\n    const device gptoss_control* control [[ buffer(3) ]],\n    uint2 gid [[thread_position_in_grid]])\n{\n    const uint num_head_dims = 64;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const float dim_idx = static_cast<float>(gid.x % (num_head_dims / 2));\n    const uint token_idx = args.token_offset + gid.y;\n    activations += gid.y * args.token_stride + gid.x;\n\n    const float2 input_vals = *activations;\n    const float inv_extrapolation_freq = metal::precise::exp(dim_idx * args.freq_scale);\n    const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;\n    const float alpha = metal::saturate(metal::fma(dim_idx, args.yarn_scale, args.yarn_offset));\n    const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);\n\n    const float phi = static_cast<float>(token_idx) * inv_freq;\n    const float yarn_multiplier = args.yarn_multiplier;\n    float cosphi;\n    const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;\n    cosphi *= yarn_multiplier;\n\n    const float output_re = input_vals.x * cosphi - input_vals.y * sinphi;\n    const float output_im = input_vals.x * sinphi + input_vals.y * cosphi;\n    *activations = (float2) { output_re, output_im };\n\n    const uint head_dim = 64;\n    const uint num_q_heads = 64;\n    const uint num_kv_heads = 8;\n    const uint head_idx = gid.x / (head_dim / 2);\n    float2 vals = (float2) { output_re, output_im };\n    if ((head_idx < num_q_heads)) {\n        *activations = vals;\n    } 
else if (head_idx < num_q_heads + num_kv_heads) {\n        // Write k and v directly to the kv cache.\n        const uint kv_head_idx = head_idx - num_q_heads;\n        const uint dim_pair_idx = gid.x % (head_dim / 2);\n        kv[(kv_head_idx * args.max_tokens + token_idx) * head_dim + dim_pair_idx] = vals;\n    }\n}"
  },
  {
    "path": "gpt_oss/metal/source/sample.metal",
    "content": "#include <metal_compute>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_simdgroup>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\ninline static uint rng_squares32(ulong offset, ulong seed) {\n    const ulong y = offset * seed;\n    const ulong z = y + seed;\n\n    /* Round 1 */\n    ulong x = y * y + y;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 2 */\n    x = x * x + z;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 3 */\n    x = x * x + y;\n    x = metal::rotate(x, 32ul);\n\n    /* Round 4 */\n    x = x * x + z;\n    return as_type<uint2>(x).y;\n}\n\nkernel void gptoss_f32_softmax(\n    constant gptoss_softmax_args& args [[ buffer(0) ]],\n    const device float* score [[ buffer(1) ]],\n    const device uint2* argmax [[ buffer(2) ]],\n    device float* prob [[ buffer(3) ]],\n    device float* sum [[ buffer(4) ]],\n    const device gptoss_control* control [[ buffer(5) ]],\n    uint tidx [[thread_index_in_threadgroup]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint2 threadgroup_size [[threads_per_threadgroup]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    threadgroup float threadgroup_sumexp[32];\n    if (control->abort != 0) {\n        return;\n    }\n\n    score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;\n    prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;\n    sum += gid.y * args.max_threadgroups;\n\n    uint max_bits = argmax[gid.y].y;\n    if (static_cast<int>(max_bits) >= 0) {\n        max_bits ^= 0x7FFFFFFFu;\n    }\n    const float max_val = as_type<float>(max_bits);\n    float sum_exp = 0.0f;\n    const uint num_vecs_per_threadgroup = metal::min(args.num_vecs - gid.x * args.num_vecs_per_threadgroup, args.num_vecs_per_threadgroup);\n    for (uint i = tidx; i < 
num_vecs_per_threadgroup; i += threadgroup_size.x) {\n        const float score_val = score[i];\n        const float prob_val = metal::precise::exp((score_val - max_val) * args.temperature);\n        prob[i] = prob_val;\n        sum_exp += prob_val;\n    }\n    sum_exp = metal::simd_sum(sum_exp);\n    if (metal::simd_is_first()) {\n        threadgroup_sumexp[simdgroup_idx] = sum_exp;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    if (simdgroup_idx == 0) {\n        // Sum-Reduce threadgroup_sumexp\n        sum_exp = 0.0f;\n        if (simdgroup_tid < num_simdgroups) {\n            sum_exp = threadgroup_sumexp[simdgroup_tid];\n        }\n        sum_exp = metal::simd_sum(sum_exp);\n        if (metal::simd_is_first()) {\n            sum[gid.x] = sum_exp;\n        }\n    }\n}\n\n[[max_total_threads_per_threadgroup(1024)]]\nkernel void gptoss_f32_sample(\n    constant gptoss_sample_args& args [[ buffer(0) ]],\n    device const float* prob [[ buffer(1) ]],\n    device const float* sum [[ buffer(2) ]],\n    device uint* prediction [[ buffer(3) ]],\n    device gptoss_control* control [[ buffer(4) ]],\n    uint tid [[thread_position_in_threadgroup]],\n    uint threadgroup_size [[threads_per_threadgroup]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    threadgroup float threadgroup_sum_buffer[32];\n    threadgroup uint threadgroup_idx_buffer[32];\n    threadgroup float threadgroup_cumsum_buffer[32];\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed);\n    float sample_cdf = static_cast<float>(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;\n\n    float cumsum = 0.0f;\n    if (tid < args.num_blocks) {\n        cumsum = sum[tid];\n    }\n    cumsum = metal::simd_prefix_inclusive_sum(cumsum);\n    if (simdgroup_tid == 31) {\n        
threadgroup_sum_buffer[simdgroup_idx] = cumsum;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;\n    if (simdgroup_tid < num_simdgroups) {\n        threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];\n        if (simdgroup_tid < simdgroup_idx) {\n            threadgroup_cumsum = threadgroup_sum;\n        }\n    }\n    threadgroup_sum = metal::simd_sum(threadgroup_sum);\n    cumsum += metal::simd_sum(threadgroup_cumsum);\n\n    sample_cdf *= threadgroup_sum;\n    sample_cdf = metal::max(sample_cdf, 0x1.0p-149f);\n\n    // Find the block: the smallest tid whose inclusive cumulative sum reaches sample_cdf (cumsum >= sample_cdf).\n    uint block_idx = args.num_blocks;\n    float block_sum = cumsum;\n    if (tid >= args.num_blocks - 1) {\n        block_idx = args.num_blocks - 1;\n        block_sum = 0.0f;\n    } else if (cumsum >= sample_cdf) {\n        block_idx = tid;\n        block_sum = 0.0f;\n    }\n    block_idx = metal::simd_min(block_idx);\n    block_sum = metal::simd_max(block_sum);\n    if (simdgroup_tid == 0) {\n        threadgroup_idx_buffer[simdgroup_idx] = block_idx;\n        threadgroup_cumsum_buffer[simdgroup_idx] = block_sum;\n    }\n    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n    if (simdgroup_tid < num_simdgroups) {\n        block_idx = threadgroup_idx_buffer[simdgroup_tid];\n        block_sum = threadgroup_cumsum_buffer[simdgroup_tid];\n    }\n    block_idx = metal::simd_min(block_idx);\n    block_sum = metal::simd_max(block_sum);\n\n    const uint block_start = args.num_dims_per_block * block_idx;\n    const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims);\n    uint offset = block_start + tid;\n    float accumulated_sum = block_sum;\n    uint sample_idx;\n\n    // This loop must be threadgroup-uniform.\n    do {\n        // Find the token: the smallest offset whose inclusive cumulative probability reaches sample_cdf (cumsum >= sample_cdf).\n        float cumsum = 0.0f;\n        if (offset < block_end) {\n         
   cumsum = prob[offset];\n        }\n        cumsum = metal::simd_prefix_inclusive_sum(cumsum);\n        if (simdgroup_tid == 31) {\n            threadgroup_sum_buffer[simdgroup_idx] = cumsum;\n        }\n        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n        float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;\n        if (simdgroup_tid < num_simdgroups) {\n            threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];\n            if (simdgroup_tid < simdgroup_idx) {\n                threadgroup_cumsum = threadgroup_sum;\n            }\n        }\n        threadgroup_sum = metal::simd_sum(threadgroup_sum);\n        cumsum += metal::simd_sum(threadgroup_cumsum);\n        cumsum += accumulated_sum;\n\n        sample_idx = block_end;\n        if (offset >= block_end) {\n            // Trigger loop exit, with the last token in the block being sampled if no other candidate was found.\n            sample_idx = block_end - 1;\n        } else if (cumsum >= sample_cdf) {\n            sample_idx = offset;\n        }\n        sample_idx = metal::simd_min(sample_idx);\n        if (simdgroup_tid == 0) {\n            threadgroup_idx_buffer[simdgroup_idx] = sample_idx;\n        }\n        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n        if (simdgroup_tid < num_simdgroups) {\n            sample_idx = threadgroup_idx_buffer[simdgroup_tid];\n        }\n        sample_idx = metal::simd_min(sample_idx);\n\n        offset += threadgroup_size;\n        accumulated_sum += threadgroup_sum;\n    } while (sample_idx == block_end);\n\n    if (tid == 0) {\n        *prediction = sample_idx;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/scatter.metal",
    "content": "#include <internal/kernel-args.h>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_stdlib>\n\n// TODO(ibrahim): This is not optimal as each thread only scatters a single float4. To amortize the\n// cost of reading the expert id and offset for a token, we should let each thread scatter several\n// float4s.\nkernel void gptoss_f32_scatter_e4(\n    constant gptoss_scatter_args& args [[ buffer(0) ]],\n    const device float* in [[ buffer(1) ]],\n    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],\n    const device uint* __restrict__ expert_offsets [[ buffer(3) ]],\n    const device uint* __restrict__ intra_expert_offsets [[ buffer(4) ]],\n    device float* out [[ buffer(5) ]],\n    uint3 gid [[thread_position_in_grid]]) \n{\n    const uint total_tokens = args.tokens;\n    const uint active_experts_per_token = args.active_experts_per_token;\n    const uint embedding_dim = args.token_stride;\n    assert(embedding_dim % 4 == 0);\n    // Hard coded to top4 for now.\n    assert(active_experts_per_token == 4);\n    const uint row_in = gid.y;\n    if (row_in >= total_tokens) {\n        return;\n    }\n    // Consecutive threads in a tg read consecutive columns of the input.\n    const uint col_in_vec4 = gid.x;\n    const uint col_in = col_in_vec4 * 4;\n    if (col_in >= embedding_dim) {\n        return;\n    }\n    // Pointer to the piece of the input that we will copy to the top4 experts.\n    const device float4* src4 =\n        reinterpret_cast<const device float4*>(in + row_in * embedding_dim + col_in);\n\n    // Get the 4 destinations -- 4 experts.\n    const uint base = row_in * active_experts_per_token;\n    const uint expert0_id = expert_predictions[base].expert_id;\n    const uint expert1_id = expert_predictions[base + 1].expert_id;\n    const uint expert2_id = expert_predictions[base + 2].expert_id;\n    const uint expert3_id = expert_predictions[base + 3].expert_id;\n    const uint 
expert0_offset = expert_offsets[expert0_id];\n    const uint expert1_offset = expert_offsets[expert1_id];\n    const uint expert2_offset = expert_offsets[expert2_id];\n    const uint expert3_offset = expert_offsets[expert3_id];\n    const uint expert0_intra_expert_offset = intra_expert_offsets[base];\n    const uint expert1_intra_expert_offset = intra_expert_offsets[base + 1];\n    const uint expert2_intra_expert_offset = intra_expert_offsets[base + 2];\n    const uint expert3_intra_expert_offset = intra_expert_offsets[base + 3];\n    device float4* dst4_0 = reinterpret_cast<device float4*>(\n        out + (expert0_offset + expert0_intra_expert_offset) * embedding_dim + col_in);\n    device float4* dst4_1 = reinterpret_cast<device float4*>(\n        out + (expert1_offset + expert1_intra_expert_offset) * embedding_dim + col_in);\n    device float4* dst4_2 = reinterpret_cast<device float4*>(\n        out + (expert2_offset + expert2_intra_expert_offset) * embedding_dim + col_in);\n    device float4* dst4_3 = reinterpret_cast<device float4*>(\n        out + (expert3_offset + expert3_intra_expert_offset) * embedding_dim + col_in);\n    const float4 data = *src4;\n    *dst4_0 = data;\n    *dst4_1 = data;\n    *dst4_2 = data;\n    *dst4_3 = data;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/sdpa.metal",
    "content": "#include <metal_geometric>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_compute>\n#include <metal_simdgroup>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n// Each threadgroup handles 8 Q heads / 1 KV head for 1 token\n\nkernel void gptoss_f32_sdpa_q8_d64(\n    constant gptoss_sdpa_args& args [[ buffer(0) ]],\n    const device float* q [[ buffer(1) ]],\n    const device float* kv [[ buffer(2) ]],\n    const device bfloat* s [[ buffer(3) ]],\n    device float* output [[ buffer(4) ]],\n    const device gptoss_control* control [[ buffer(6) ]],\n    threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],\n    uint2 gid [[threadgroup_position_in_grid]],\n    uint2 tid [[thread_position_in_threadgroup]],\n    uint simdgroup_tid [[thread_index_in_simdgroup]],\n    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],\n    uint num_simdgroups [[simdgroups_per_threadgroup]])\n{\n    const uint simdgroup_size = 32;\n    if (control->abort != 0) {\n        return;\n    }\n\n    const uint num_q_heads = 64;\n    const uint head_dim = 64;\n    const uint qmul = 8;\n\n    const uint token_stride = 2 * head_dim;\n\n    const uint qt = gid.x;  // Q token index\n    const uint h = gid.y;   // KV head index\n\n    q += qt * args.qkv_dim + h * (qmul * head_dim);\n    kv += h * args.kv_stride;\n    output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);\n\n    float m0 = static_cast<float>(s[h * qmul + 0]);\n    float m1 = static_cast<float>(s[h * qmul + 1]);\n    float m2 = static_cast<float>(s[h * qmul + 2]);\n    float m3 = static_cast<float>(s[h * qmul + 3]);\n    float m4 = static_cast<float>(s[h * qmul + 4]);\n    float m5 = static_cast<float>(s[h * qmul + 5]);\n    float m6 = static_cast<float>(s[h * qmul + 6]);\n    float m7 = static_cast<float>(s[h * qmul + 7]);\n\n    float l0 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l1 = simdgroup_idx == 0 ? 
1.0f : 0.0f;\n    float l2 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l3 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l4 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l5 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l6 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n    float l7 = simdgroup_idx == 0 ? 1.0f : 0.0f;\n\n    float2 out0 = 0.0f;\n    float2 out1 = 0.0f;\n    float2 out2 = 0.0f;\n    float2 out3 = 0.0f;\n    float2 out4 = 0.0f;\n    float2 out5 = 0.0f;\n    float2 out6 = 0.0f;\n    float2 out7 = 0.0f;\n\n    float2 q0 = reinterpret_cast<const device float2*>(q + 0 * head_dim)[simdgroup_tid];\n    float2 q1 = reinterpret_cast<const device float2*>(q + 1 * head_dim)[simdgroup_tid];\n    float2 q2 = reinterpret_cast<const device float2*>(q + 2 * head_dim)[simdgroup_tid];\n    float2 q3 = reinterpret_cast<const device float2*>(q + 3 * head_dim)[simdgroup_tid];\n    float2 q4 = reinterpret_cast<const device float2*>(q + 4 * head_dim)[simdgroup_tid];\n    float2 q5 = reinterpret_cast<const device float2*>(q + 5 * head_dim)[simdgroup_tid];\n    float2 q6 = reinterpret_cast<const device float2*>(q + 6 * head_dim)[simdgroup_tid];\n    float2 q7 = reinterpret_cast<const device float2*>(q + 7 * head_dim)[simdgroup_tid];\n\n    const uint kt_end = qt + args.num_kv_tokens + 1;\n    const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx;\n    kv += token_stride * kt_start;\n    for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) {\n        const float2 kval = reinterpret_cast<const device float2*>(kv)[simdgroup_tid];\n\n        float qk0 = metal::dot(q0, kval);\n        float qk1 = metal::dot(q1, kval);\n        float qk2 = metal::dot(q2, kval);\n        float qk3 = metal::dot(q3, kval);\n        float qk4 = metal::dot(q4, kval);\n        float qk5 = metal::dot(q5, kval);\n        float qk6 = metal::dot(q6, kval);\n        float qk7 = metal::dot(q7, kval);\n\n        qk0 = metal::simd_sum(qk0);\n        qk1 = metal::simd_sum(qk1);\n        qk2 = 
metal::simd_sum(qk2);\n        qk3 = metal::simd_sum(qk3);\n        qk4 = metal::simd_sum(qk4);\n        qk5 = metal::simd_sum(qk5);\n        qk6 = metal::simd_sum(qk6);\n        qk7 = metal::simd_sum(qk7);\n\n        const float new_m0 = metal::max(m0, qk0);\n        const float new_m1 = metal::max(m1, qk1);\n        const float new_m2 = metal::max(m2, qk2);\n        const float new_m3 = metal::max(m3, qk3);\n        const float new_m4 = metal::max(m4, qk4);\n        const float new_m5 = metal::max(m5, qk5);\n        const float new_m6 = metal::max(m6, qk6);\n        const float new_m7 = metal::max(m7, qk7);\n\n        const float alpha0 = metal::fast::exp(m0 - new_m0);\n        const float alpha1 = metal::fast::exp(m1 - new_m1);\n        const float alpha2 = metal::fast::exp(m2 - new_m2);\n        const float alpha3 = metal::fast::exp(m3 - new_m3);\n        const float alpha4 = metal::fast::exp(m4 - new_m4);\n        const float alpha5 = metal::fast::exp(m5 - new_m5);\n        const float alpha6 = metal::fast::exp(m6 - new_m6);\n        const float alpha7 = metal::fast::exp(m7 - new_m7);\n\n        qk0 = metal::fast::exp(qk0 - new_m0);\n        qk1 = metal::fast::exp(qk1 - new_m1);\n        qk2 = metal::fast::exp(qk2 - new_m2);\n        qk3 = metal::fast::exp(qk3 - new_m3);\n        qk4 = metal::fast::exp(qk4 - new_m4);\n        qk5 = metal::fast::exp(qk5 - new_m5);\n        qk6 = metal::fast::exp(qk6 - new_m6);\n        qk7 = metal::fast::exp(qk7 - new_m7);\n\n        l0 = metal::fma(l0, alpha0, qk0);\n        l1 = metal::fma(l1, alpha1, qk1);\n        l2 = metal::fma(l2, alpha2, qk2);\n        l3 = metal::fma(l3, alpha3, qk3);\n        l4 = metal::fma(l4, alpha4, qk4);\n        l5 = metal::fma(l5, alpha5, qk5);\n        l6 = metal::fma(l6, alpha6, qk6);\n        l7 = metal::fma(l7, alpha7, qk7);\n\n        m0 = new_m0;\n        m1 = new_m1;\n        m2 = new_m2;\n        m3 = new_m3;\n        m4 = new_m4;\n        m5 = new_m5;\n        m6 = new_m6;\n        m7 
= new_m7;\n\n        const float2 vval = reinterpret_cast<const device float2*>(kv + head_dim)[simdgroup_tid];\n        kv += token_stride * num_simdgroups;\n        out0 = metal::fma(vval, qk0, out0 * alpha0);\n        out1 = metal::fma(vval, qk1, out1 * alpha1);\n        out2 = metal::fma(vval, qk2, out2 * alpha2);\n        out3 = metal::fma(vval, qk3, out3 * alpha3);\n        out4 = metal::fma(vval, qk4, out4 * alpha4);\n        out5 = metal::fma(vval, qk5, out5 * alpha5);\n        out6 = metal::fma(vval, qk6, out6 * alpha6);\n        out7 = metal::fma(vval, qk7, out7 * alpha7);\n    }\n    if (num_simdgroups > 1) {\n        if (metal::simd_is_first()) {\n            static_cast<threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_idx] = m0;\n            static_cast<threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_idx] = m1;\n            static_cast<threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_idx] = m2;\n            static_cast<threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_idx] = m3;\n            static_cast<threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_idx] = m4;\n            static_cast<threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_idx] = m5;\n            static_cast<threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_idx] = m6;\n            static_cast<threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_idx] = m7;\n\n            static_cast<threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_idx] = l0;\n            static_cast<threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_idx] = l1;\n            static_cast<threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_idx] = l2;\n            static_cast<threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_idx] = l3;\n            static_cast<threadgroup 
float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_idx] = l4;\n            static_cast<threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_idx] = l5;\n            static_cast<threadgroup float*>(threadgroup_buffer)[14 * num_simdgroups + simdgroup_idx] = l6;\n            static_cast<threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_idx] = l7;\n        }\n        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n        // Note: simdgroup refers not to the thread's current simdgroup, but to one with simdgroup_idx == thread's simdgroup_tid.\n        float simdgroup_m0 = m0;\n        float simdgroup_m1 = m1;\n        float simdgroup_m2 = m2;\n        float simdgroup_m3 = m3;\n        float simdgroup_m4 = m4;\n        float simdgroup_m5 = m5;\n        float simdgroup_m6 = m6;\n        float simdgroup_m7 = m7;\n        if (simdgroup_tid < num_simdgroups) {\n            simdgroup_m0 = static_cast<const threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_tid];\n            simdgroup_m1 = static_cast<const threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_tid];\n            simdgroup_m2 = static_cast<const threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_tid];\n            simdgroup_m3 = static_cast<const threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_tid];\n            simdgroup_m4 = static_cast<const threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_tid];\n            simdgroup_m5 = static_cast<const threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_tid];\n            simdgroup_m6 = static_cast<const threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_tid];\n            simdgroup_m7 = static_cast<const threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_tid];\n        }\n\n        const float threadgroup_m0 = metal::simd_max(simdgroup_m0);\n     
   const float threadgroup_m1 = metal::simd_max(simdgroup_m1);\n        const float threadgroup_m2 = metal::simd_max(simdgroup_m2);\n        const float threadgroup_m3 = metal::simd_max(simdgroup_m3);\n        const float threadgroup_m4 = metal::simd_max(simdgroup_m4);\n        const float threadgroup_m5 = metal::simd_max(simdgroup_m5);\n        const float threadgroup_m6 = metal::simd_max(simdgroup_m6);\n        const float threadgroup_m7 = metal::simd_max(simdgroup_m7);\n\n        out0 *= metal::fast::exp(m0 - threadgroup_m0);\n        out1 *= metal::fast::exp(m1 - threadgroup_m1);\n        out2 *= metal::fast::exp(m2 - threadgroup_m2);\n        out3 *= metal::fast::exp(m3 - threadgroup_m3);\n        out4 *= metal::fast::exp(m4 - threadgroup_m4);\n        out5 *= metal::fast::exp(m5 - threadgroup_m5);\n        out6 *= metal::fast::exp(m6 - threadgroup_m6);\n        out7 *= metal::fast::exp(m7 - threadgroup_m7);\n\n        if (simdgroup_idx == 0) {\n            l0 = 0.0f;\n            l1 = 0.0f;\n            l2 = 0.0f;\n            l3 = 0.0f;\n            l4 = 0.0f;\n            l5 = 0.0f;\n            l6 = 0.0f;\n            l7 = 0.0f;\n            if (simdgroup_tid < num_simdgroups) {\n                l0 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_tid];\n                l1 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_tid];\n                l2 = static_cast<const threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_tid];\n                l3 = static_cast<const threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_tid];\n                l4 = static_cast<const threadgroup float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_tid];\n                l5 = static_cast<const threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_tid];\n                l6 = static_cast<const threadgroup float*>(threadgroup_buffer)[14 * 
num_simdgroups + simdgroup_tid];\n                l7 = static_cast<const threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_tid];\n            }\n\n            l0 = metal::simd_sum(l0 * metal::fast::exp(simdgroup_m0 - threadgroup_m0));\n            l1 = metal::simd_sum(l1 * metal::fast::exp(simdgroup_m1 - threadgroup_m1));\n            l2 = metal::simd_sum(l2 * metal::fast::exp(simdgroup_m2 - threadgroup_m2));\n            l3 = metal::simd_sum(l3 * metal::fast::exp(simdgroup_m3 - threadgroup_m3));\n            l4 = metal::simd_sum(l4 * metal::fast::exp(simdgroup_m4 - threadgroup_m4));\n            l5 = metal::simd_sum(l5 * metal::fast::exp(simdgroup_m5 - threadgroup_m5));\n            l6 = metal::simd_sum(l6 * metal::fast::exp(simdgroup_m6 - threadgroup_m6));\n            l7 = metal::simd_sum(l7 * metal::fast::exp(simdgroup_m7 - threadgroup_m7));\n        }\n\n        uint num_threads = num_simdgroups * simdgroup_size;\n        do {\n            const uint num_smem_threads = (num_threads / 2) & -simdgroup_size;\n            const uint num_half_threads = num_threads - num_smem_threads;\n\n            metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n            const uint smem_tid = tid.x - num_half_threads;\n            if (smem_tid < num_smem_threads) {\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + smem_tid] = out0;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + smem_tid] = out1;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + smem_tid] = out2;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + smem_tid] = out3;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + smem_tid] = out4;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + smem_tid] = out5;\n                
static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + smem_tid] = out6;\n                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + smem_tid] = out7;\n            }\n            metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);\n            if (tid.x < num_smem_threads) {\n                out0 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + tid.x];\n                out1 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + tid.x];\n                out2 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + tid.x];\n                out3 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + tid.x];\n                out4 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + tid.x];\n                out5 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + tid.x];\n                out6 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + tid.x];\n                out7 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + tid.x];\n            }\n\n            num_threads = num_half_threads;\n        } while (num_threads > simdgroup_size);\n    }\n    if (simdgroup_idx == 0) {\n        reinterpret_cast<device float2*>(output + 0 * head_dim)[simdgroup_tid] = out0 / l0;\n        reinterpret_cast<device float2*>(output + 1 * head_dim)[simdgroup_tid] = out1 / l1;\n        reinterpret_cast<device float2*>(output + 2 * head_dim)[simdgroup_tid] = out2 / l2;\n        reinterpret_cast<device float2*>(output + 3 * head_dim)[simdgroup_tid] = out3 / l3;\n        reinterpret_cast<device float2*>(output + 4 * head_dim)[simdgroup_tid] = out4 / l4;\n        reinterpret_cast<device float2*>(output + 5 * head_dim)[simdgroup_tid] = out5 / l5;\n        
reinterpret_cast<device float2*>(output + 6 * head_dim)[simdgroup_tid] = out6 / l6;\n        reinterpret_cast<device float2*>(output + 7 * head_dim)[simdgroup_tid] = out7 / l7;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/tokenizer.c",
    "content": "#include <assert.h>\n#include <stdatomic.h>\n#include <stddef.h>\n#include <stdint.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <errno.h>\n#include <sys/mman.h>\n\n#include <gpt-oss.h>\n\n#include \"internal/log.h\"\n#include \"internal/model.h\"\n\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(\n    gptoss_tokenizer_t tokenizer,\n    enum gptoss_special_token token_type,\n    uint32_t* token_id_out)\n{\n    uint32_t token_id = UINT32_MAX;\n    if (token_type != gptoss_special_token_invalid && token_type < gptoss_special_token_max)\n    {\n        token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];\n    }\n    if (token_id == UINT32_MAX) {\n        return gptoss_status_invalid_argument;\n    }\n\n    *token_id_out = token_id;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* num_text_tokens_out)\n{\n    *num_text_tokens_out = tokenizer->num_text_tokens;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* num_special_tokens_out)\n{\n    *num_special_tokens_out = tokenizer->num_special_tokens;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t* num_tokens_out)\n{\n    *num_tokens_out = tokenizer->num_text_tokens + tokenizer->num_special_tokens;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(\n    gptoss_tokenizer_t tokenizer,\n    uint32_t token_id,\n    const void** token_ptr_out,\n    size_t* token_size_out)\n{\n    if (token_id >= tokenizer->num_text_tokens) {\n        return gptoss_status_invalid_argument;\n    }\n\n    const char* token_ptr = (const char*) tokenizer->tokens_ptr;\n    for (uint32_t t = 0; t < token_id; t++) {\n        
// Reading unaligned uint16_t\n        uint16_t token_length;\n        memcpy(&token_length, token_ptr, sizeof(token_length));\n\n        token_ptr += (size_t) token_length + sizeof(uint16_t);\n    }\n\n    // The length prefix is a full (possibly unaligned) uint16_t: dereferencing token_ptr as a char\n    // would read only one byte of it, and would sign-extend that byte where char is signed.\n    uint16_t token_length;\n    memcpy(&token_length, token_ptr, sizeof(token_length));\n\n    *token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));\n    *token_size_out = (size_t) token_length;\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(\n    gptoss_tokenizer_t tokenizer)\n{\n    atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);\n    return gptoss_status_success;\n}\n\nenum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(\n    gptoss_tokenizer_t tokenizer)\n{\n    if (tokenizer != NULL) {\n        // acq_rel: the release half publishes this thread's prior writes to the tokenizer, and the\n        // acquire half lets the thread that drops the last reference observe all of them before freeing.\n        if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acq_rel) == 1) {\n            if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {\n                if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {\n                    GPTOSS_LOG_WARNING(\"munmap for tokenizer mapping failed with error %d\", errno);\n                }\n            }\n\n            memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));\n            free(tokenizer);\n        }\n    }\n    return gptoss_status_success;\n}\n"
  },
  {
    "path": "gpt_oss/metal/source/topk.metal",
    "content": "#include <metal_compute>\n#include <metal_integer>\n#include <metal_math>\n#include <metal_simdgroup>\n\n#include <internal/kernel-args.h>\n\n#pragma METAL fp math_mode(safe)\n#pragma METAL fp contract(off)\n\n\n[[max_total_threads_per_threadgroup(32)]]\nkernel void gptoss_f32_topk_softmax_e128_k4(\n    constant gptoss_topk_args& args [[ buffer(0) ]],\n    const device float4* input [[ buffer(1) ]],\n    device gptoss_expert_prediction* output [[ buffer(2) ]],\n    const device gptoss_control* control [[ buffer(3) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]])\n{\n    const uint num_experts = 128;\n    const uint num_active_experts = 4;\n    if (control->abort != 0) {\n        return;\n    }\n\n    input += gid * (num_experts / 4);\n    output += gid * num_active_experts;\n\n    uint4 idx = tid * 4 + (uint4) {0, 1, 2, 3};\n    float4 val = input[tid];\n\n    const float topval0 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));\n    uint idx0 = 0xFFFFFFFFu;\n    if (val.w == topval0) {\n        idx0 = idx.w;\n    }\n    if (val.z == topval0) {\n        idx0 = idx.z;\n    }\n    if (val.y == topval0) {\n        idx0 = idx.y;\n    }\n    if (val.x == topval0) {\n        idx0 = idx.x;\n    }\n    const uint topidx0 = metal::simd_min(idx0);\n    const bool4 is_topidx0 = idx == topidx0;\n    val = metal::select(val, -INFINITY, is_topidx0);\n    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx0);\n\n    const float topval1 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));\n    uint idx1 = 0xFFFFFFFFu;\n    if (val.w == topval1) {\n        idx1 = idx.w;\n    }\n    if (val.z == topval1) {\n        idx1 = idx.z;\n    }\n    if (val.y == topval1) {\n        idx1 = idx.y;\n    }\n    if (val.x == topval1) {\n        idx1 = idx.x;\n    }\n    const uint topidx1 = metal::simd_min(idx1);\n    const bool4 is_topidx1 = idx == topidx1;\n    val = metal::select(val, 
-INFINITY, is_topidx1);\n    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx1);\n\n    const float topval2 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));\n    uint idx2 = 0xFFFFFFFFu;\n    if (val.w == topval2) {\n        idx2 = idx.w;\n    }\n    if (val.z == topval2) {\n        idx2 = idx.z;\n    }\n    if (val.y == topval2) {\n        idx2 = idx.y;\n    }\n    if (val.x == topval2) {\n        idx2 = idx.x;\n    }\n    const uint topidx2 = metal::simd_min(idx2);\n    const bool4 is_topidx2 = idx == topidx2;\n    val = metal::select(val, -INFINITY, is_topidx2);\n    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx2);\n\n    const float topval3 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));\n    uint idx3 = 0xFFFFFFFFu;\n    if (val.w == topval3) {\n        idx3 = idx.w;\n    }\n    if (val.z == topval3) {\n        idx3 = idx.z;\n    }\n    if (val.y == topval3) {\n        idx3 = idx.y;\n    }\n    if (val.x == topval3) {\n        idx3 = idx.x;\n    }\n    const uint topidx3 = metal::simd_min(idx3);\n\n    if (metal::simd_is_first()) {\n        const float topexp0 = 1.0f;\n        const float topexp1 = metal::precise::exp(topval1 - topval0);\n        const float topexp2 = metal::precise::exp(topval2 - topval0);\n        const float topexp3 = metal::precise::exp(topval3 - topval0);\n\n        const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);\n        const float scale = 1.0 / sum;\n\n        output[0] = (gptoss_expert_prediction) {\n            .expert_id = topidx0,\n            .score = topexp0 * scale,\n        };\n        output[1] = (gptoss_expert_prediction) {\n            .expert_id = topidx1,\n            .score = topexp1 * scale,\n        };\n        output[2] = (gptoss_expert_prediction) {\n            .expert_id = topidx2,\n            .score = topexp2 * scale,\n        };\n        output[3] = (gptoss_expert_prediction) {\n            .expert_id = topidx3,\n            .score = topexp3 * scale,\n  
      };\n    }\n}\n\n[[max_total_threads_per_threadgroup(32)]]\nkernel void gptoss_f32_topk_softmax_e32_k4(\n    constant gptoss_topk_args& args [[ buffer(0) ]],\n    const device float* input [[ buffer(1) ]],\n    device gptoss_expert_prediction* output [[ buffer(2) ]],\n    const device gptoss_control* control [[ buffer(3) ]],\n    uint gid [[threadgroup_position_in_grid]],\n    uint tid [[thread_position_in_threadgroup]])\n{\n    const uint num_experts = 32;\n    const uint num_active_experts = 4;\n    if (control->abort != 0) {\n        return;\n    }\n\n    input += gid * num_experts;\n    output += gid * num_active_experts;\n\n    float val = input[tid];\n    uint idx = tid;\n\n    const float topval0 = metal::simd_max(val);\n    const uint topidx0 = metal::simd_min(val == topval0 ? idx : 0xFFFFFFFFu);\n    if (idx == topidx0) {\n        val = -INFINITY;\n        idx = 0xFFFFFFFFu;\n    }\n\n    const float topval1 = metal::simd_max(val);\n    const uint topidx1 = metal::simd_min(val == topval1 ? idx : 0xFFFFFFFFu);\n    if (idx == topidx1) {\n        val = -INFINITY;\n        idx = 0xFFFFFFFFu;\n    }\n\n    const float topval2 = metal::simd_max(val);\n    const uint topidx2 = metal::simd_min(val == topval2 ? idx : 0xFFFFFFFFu);\n    if (idx == topidx2) {\n        val = -INFINITY;\n        idx = 0xFFFFFFFFu;\n    }\n\n    const float topval3 = metal::simd_max(val);\n    const uint topidx3 = metal::simd_min(val == topval3 ? 
idx : 0xFFFFFFFFu);\n\n    if (metal::simd_is_first()) {\n        const float topexp0 = 1.0f;\n        const float topexp1 = metal::precise::exp(topval1 - topval0);\n        const float topexp2 = metal::precise::exp(topval2 - topval0);\n        const float topexp3 = metal::precise::exp(topval3 - topval0);\n\n        const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);\n        const float scale = 1.0 / sum;\n\n        output[0] = (gptoss_expert_prediction) {\n            .expert_id = topidx0,\n            .score = topexp0 * scale,\n        };\n        output[1] = (gptoss_expert_prediction) {\n            .expert_id = topidx1,\n            .score = topexp1 * scale,\n        };\n        output[2] = (gptoss_expert_prediction) {\n            .expert_id = topidx2,\n            .score = topexp2 * scale,\n        };\n        output[3] = (gptoss_expert_prediction) {\n            .expert_id = topidx3,\n            .score = topexp3 * scale,\n        };\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/bf16-f32-embeddings.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cstddef>\n\n#include \"embeddings-kernel-tester.hpp\"\n\n\nusing gptoss::EmbeddingsKernelTester;\n\nconstexpr std::size_t kThreadgroupSize = 64;\n\n\nTEST(BF16_F32_EMBEDDINGS, single_token_single_tile) {\n    EmbeddingsKernelTester()\n        .num_channels(kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .TestBF16_F32();\n}\n\nTEST(BF16_F32_EMBEDDINGS, single_token_multi_tile) {\n    EmbeddingsKernelTester()\n        .num_channels(kThreadgroupSize * 4 + 16)\n        .threadgroup_size(kThreadgroupSize)\n        .TestBF16_F32();\n}\n\nTEST(BF16_F32_EMBEDDINGS, multiple_tokens) {\n    EmbeddingsKernelTester()\n        .num_channels(kThreadgroupSize * 4 + 16)\n        .num_tokens(3)\n        .threadgroup_size(kThreadgroupSize)\n        .TestBF16_F32();\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/embeddings-kernel-tester.hpp",
    "content": "#pragma once\n\n#include <gtest/gtest.h>\n\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n\n#include <internal/datatype.hpp>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n\nnamespace gptoss {\n\nclass EmbeddingsKernelTester {\npublic:\n    EmbeddingsKernelTester() { }\n\n    EmbeddingsKernelTester(const EmbeddingsKernelTester&) = delete;\n    EmbeddingsKernelTester(EmbeddingsKernelTester&&) = delete;\n    EmbeddingsKernelTester& operator=(const EmbeddingsKernelTester&) = delete;\n    EmbeddingsKernelTester& operator=(EmbeddingsKernelTester&&) = delete;\n\n    [[nodiscard]]\n    EmbeddingsKernelTester& num_channels(std::uint32_t num_channels) {\n        num_channels_ = num_channels;\n        return *this;\n    }\n\n    std::uint32_t num_channels() const {\n        return num_channels_;\n    }\n\n    [[nodiscard]]\n    EmbeddingsKernelTester& num_tokens(std::uint32_t num_tokens) {\n        num_tokens_ = num_tokens;\n        return *this;\n    }\n\n    std::uint32_t num_tokens() const {\n        return num_tokens_;\n    }\n\n    std::uint32_t vocabulary_size() const {\n        return num_tokens() + 1;\n    }\n\n    [[nodiscard]]\n    EmbeddingsKernelTester& threadgroup_size(std::size_t threadgroup_size) {\n        threadgroup_size_ = threadgroup_size;\n        return *this;\n    }\n\n    std::size_t threadgroup_size() const {\n        return threadgroup_size_;\n    }\n\n    void Validate() const {\n        ASSERT_NE(num_channels(), 0);\n        ASSERT_NE(num_tokens(), 0);\n        ASSERT_NE(threadgroup_size(), 0);\n        ASSERT_EQ(threadgroup_size() % 32, 0);\n    }\n\n    void TestBF16_F32() const {\n        Validate();\n\n        metal::CommandBuffer command_buffer{command_queue_};\n        // One token ID per token, and one output row of num_channels() floats per token.\n        metal::Buffer token_buffer{device_, num_tokens() * sizeof(std::uint32_t)};\n        metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};\n        metal::Buffer output_buffer{device_, num_tokens() * num_channels() * 
sizeof(float)};\n        metal::Buffer control_buffer{device_, sizeof(gptoss_control)};\n        std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));\n\n        std::uint32_t* token_ptr = static_cast<std::uint32_t*>(token_buffer.ptr());\n        for (std::uint32_t t = 0; t < num_tokens(); t++) {\n            token_ptr[t] = t + 1;\n        }\n\n        Check(gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(\n                command_buffer.handle(),\n                bf16_f32_embeddings_fn.handle(),\n                threadgroup_size(),\n                token_buffer.handle(),\n                /*token_offset=*/0,\n                weight_buffer.handle(),\n                /*weight_offset=*/0,\n                output_buffer.handle(),\n                /*output_offset=*/0,\n                control_buffer.handle(),\n                /*control_offset=*/0,\n                num_tokens(),\n                num_channels()),\n            \"gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings\");\n\n        command_buffer.commit();\n        command_buffer.wait_completion();\n\n        const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());\n        const float* output_ptr = static_cast<const float*>(output_buffer.ptr());\n        for (std::uint32_t t = 0; t < num_tokens(); t++) {\n            const std::uint32_t token = token_ptr[t];\n            for (std::uint32_t i = 0; i < num_channels(); i++) {\n                const gptoss_bfloat16 input_val = weight_ptr[token * num_channels() + i];\n                const float ref_output = upcast<float>(input_val);\n                const float output = output_ptr[t * num_channels() + i];\n                ASSERT_EQ(output, ref_output)\n                    << \"at token \" << t << \", position \" << i << \" / \" << num_channels() << \", input \" << std::uint32_t(input_val.bits);\n            }\n        }\n    }\n\nprivate:\n    metal::Device device_{};\n    metal::CommandQueue 
command_queue_{device_};\n    metal::Library library_{device_};\n    metal::Function bf16_f32_embeddings_fn{library_, \"gptoss_bf16_f32_embeddings\"};\n    std::uint32_t num_tokens_{1};\n    std::uint32_t num_channels_{1};\n    std::size_t threadgroup_size_{32};\n};\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/test/f32-bf16w-matmul.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cstddef>\n#include <cstdint>\n\n#include \"matmul-kernel-tester.hpp\"\n\n\nusing gptoss::MatMulKernelTester;\n\nconstexpr size_t kSimdgroupSize = 32;  // fixed in the kernel\n\nTEST(F32_BF16W_MATMUL, single_simdgroup_single_iteration) {\n    MatMulKernelTester()\n        .num_rows(1)\n        .num_cols(kSimdgroupSize * 4)\n        .threadgroup_size(kSimdgroupSize)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_MATMUL, single_simdgroup_multiple_iteration) {\n    MatMulKernelTester()\n        .num_rows(1)\n        .num_cols((2 * kSimdgroupSize + 1) * 4)\n        .threadgroup_size(kSimdgroupSize)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_MATMUL, single_threadgroup) {\n    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;\n\n    MatMulKernelTester()\n        .num_rows(threadgroup_size / kSimdgroupSize)\n        .num_cols((2 * kSimdgroupSize + 1) * 4)\n        .threadgroup_size(threadgroup_size)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_MATMUL, multiple_threadgroups) {\n    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;\n    constexpr std::uint32_t num_threadgroups = 3;\n\n    MatMulKernelTester()\n        .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)\n        .num_cols((2 * kSimdgroupSize + 1) * 4)\n        .threadgroup_size(threadgroup_size)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_MATMUL, multiple_tokens) {\n    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;\n    constexpr std::uint32_t num_threadgroups = 3;\n\n    MatMulKernelTester()\n        .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)\n        .num_cols((2 * kSimdgroupSize + 1) * 4)\n        .num_tokens(2)\n        .threadgroup_size(threadgroup_size)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_DENSE_MATMUL_QKV, seq_len_1024) {\n    MatMulKernelTester()\n        .num_tokens(1024)\n        .num_rows(5120)\n        .num_cols(2880)\n        .TestF32_BF16W(\n         
   MatMulKernelTester::MatMulKernelType::PREFILL_QKV_OPTIMIZED);\n}\n\nTEST(F32_BF16W_DENSE_MATMUL_ATTN_OUTPUT, seq_len_1024) {\n    MatMulKernelTester()\n        .num_tokens(1024)\n        .num_rows(2880)\n        .num_cols(4096)\n        .TestF32_BF16W(\n            MatMulKernelTester::MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED);\n}\n\nTEST(F32_BF16W_DENSE_MATMUL_MLP_GATE, seq_len_1024) {\n    MatMulKernelTester()\n        .num_tokens(1024)\n        .num_rows(128)\n        .num_cols(2880)\n        .TestF32_BF16W(\n            MatMulKernelTester::MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED);\n}"
  },
  {
    "path": "gpt_oss/metal/test/f32-bf16w-rmsnorm.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cstdint>\n\n#include \"rmsnorm-kernel-tester.hpp\"\n\n\nusing gptoss::RMSNormKernelTester;\n\nconstexpr std::uint32_t kThreadgroupSize = 1024;  // fixed in the kernel\nconstexpr std::uint32_t kVectorSize = 4;  // fixed in the kernel\n\nTEST(F32_BF16W_RMSNORM, single_iteration) {\n    RMSNormKernelTester()\n        .num_channels(kThreadgroupSize)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_RMSNORM, multiple_iterations) {\n    RMSNormKernelTester()\n        .num_channels(kThreadgroupSize * 2)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_RMSNORM, partial_iteration) {\n    RMSNormKernelTester()\n        .num_channels(kThreadgroupSize * 2 + kVectorSize)\n        .TestF32_BF16W();\n}\n\nTEST(F32_BF16W_RMSNORM, multiple_tokens) {\n    RMSNormKernelTester()\n        .num_tokens(3)\n        .num_channels(kThreadgroupSize * 2 + kVectorSize)\n        .TestF32_BF16W();\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/f32-random.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cmath>\n\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n#include <internal/rng.hpp>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\n\nconstexpr uint64_t kSeed = UINT64_C(1019827666124465388);\nconstexpr uint64_t kOffset = UINT64_C(12345678901234567890);\nconstexpr float kMin = -1.0f;\nconstexpr float kMax = +1.5f;\nconstexpr float kScale = (kMax - kMin) * 0.5f;\nconstexpr float kBias = (kMin + kMax) * 0.5f;\nconstexpr size_t kThreadgroupSize = 128;\n\nTEST(F32_FILL_RANDOM, single_threadgroup_single_iteration) {\n    constexpr size_t num_bytes = kThreadgroupSize * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/1,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n        const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    
}\n}\n\nTEST(F32_FILL_RANDOM, single_threadgroup_multiple_iterations) {\n    constexpr size_t num_iterations = 3;\n    constexpr size_t num_bytes = num_iterations * kThreadgroupSize * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/1,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n        const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    }\n}\n\nTEST(F32_FILL_RANDOM, multiple_threadgroups_multiple_iterations) {\n    constexpr size_t num_iterations = 3;\n    constexpr size_t num_threadgroups = 2;\n    constexpr size_t num_bytes = num_iterations * num_threadgroups * kThreadgroupSize * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function 
f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/num_threadgroups,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n        const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    }\n}\n\nTEST(F32_FILL_RANDOM, excessive_threadgroups) {\n    constexpr size_t num_bytes = kThreadgroupSize * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/2,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        
\"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n        const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    }\n}\n\nTEST(F32_FILL_RANDOM, nonuniform_range) {\n    constexpr size_t num_iterations = 3;\n    constexpr size_t num_threadgroups = 2;\n    constexpr size_t num_bytes = (num_iterations * num_threadgroups + 1) * kThreadgroupSize * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/num_threadgroups,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n       
 const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    }\n}\n\nTEST(F32_FILL_RANDOM, partial_range) {\n    constexpr size_t num_iterations = 3;\n    constexpr size_t num_threadgroups = 2;\n    constexpr size_t num_bytes = (num_iterations * num_threadgroups * kThreadgroupSize + 1) * 16;\n    constexpr size_t num_elements = num_bytes / sizeof(uint32_t);\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function f32_fill_random_fn{library, \"gptoss_f32_fill_random\"};\n    Buffer buffer{device, num_elements * sizeof(float)};\n\n    Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(\n            command_buffer.handle(),\n            f32_fill_random_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/num_threadgroups,\n            /*output_buffer=*/buffer.handle(),\n            /*output_offset=*/0,\n            num_elements, kSeed, kOffset, kMin, kMax),\n        \"gptoss_metal_command_buffer_encode_launch_f32_fill_random\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(buffer.ptr());\n    for (size_t i = 0; i < num_elements; i++) {\n        const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);\n        const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;\n        const float ref_value = std::fma(ref_float, kScale, kBias);\n        ASSERT_EQ(output_ptr[i], ref_value)\n            << \"at position \" << i << \" / \" << num_elements;\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/f32-rope.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cstddef>\n#include <cstdint>\n\n#include \"rope-kernel-tester.hpp\"\n\n\nusing gptoss::RoPEKernelTester;\n\nconstexpr float kFrequencyBase = 50000.0f;\nconstexpr std::uint32_t kHeadDim = 64;  // fixed in the kernel\nconstexpr std::uint32_t kTokenOffset = 7;\n\n\nTEST(F32_ROPE, single_simdgroup) {\n    RoPEKernelTester()\n        .head_dim(kHeadDim)\n        .num_q_heads(1)\n        .num_kv_heads(0)\n        .token_offset(kTokenOffset)\n        .frequency_base(kFrequencyBase)\n        .threadgroup_size(kHeadDim / 2)\n        .TestF32();\n}\n\nTEST(F32_ROPE, single_threadgroup) {\n    constexpr std::size_t threadgroup_size = 64;\n    constexpr std::uint32_t num_heads = threadgroup_size / (kHeadDim / 2);\n\n    RoPEKernelTester()\n        .head_dim(kHeadDim)\n        .num_q_heads(num_heads)\n        .num_kv_heads(0)\n        .token_offset(kTokenOffset)\n        .frequency_base(kFrequencyBase)\n        .threadgroup_size(threadgroup_size)\n        .TestF32();\n}\n\nTEST(F32_ROPE, multiple_threadgroups) {\n    constexpr std::uint32_t num_threadgroups = 3;\n    constexpr std::size_t threadgroup_size = 64;\n    constexpr std::uint32_t num_heads = num_threadgroups * (threadgroup_size / (kHeadDim / 2));\n\n    RoPEKernelTester()\n        .head_dim(kHeadDim)\n        .num_q_heads(num_heads)\n        .num_kv_heads(0)\n        .token_offset(kTokenOffset)\n        .frequency_base(kFrequencyBase)\n        .threadgroup_size(threadgroup_size)\n        .TestF32();\n}\n\nTEST(F32_ROPE, multiple_tokens) {\n    constexpr std::uint32_t num_tokens = 2;\n    constexpr std::uint32_t num_threadgroups = 3;\n    constexpr std::size_t threadgroup_size = 64;\n    constexpr std::uint32_t num_heads = num_threadgroups * (threadgroup_size / (kHeadDim / 2));\n\n    RoPEKernelTester()\n        .head_dim(kHeadDim)\n        .num_tokens(num_tokens)\n        .num_q_heads(num_heads)\n        .num_kv_heads(0)\n        .token_offset(kTokenOffset)\n        .frequency_base(kFrequencyBase)\n        .threadgroup_size(threadgroup_size)\n        .TestF32();\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/fill-random-kernel-tester.hpp",
    "content": "#pragma once\n\n#include <gtest/gtest.h>\n\n#include <cstddef>\n#include <cstdint>\n\n#include <internal/datatype.hpp>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n#include <internal/rng.hpp>\n\n\nnamespace gptoss {\n\nclass FillRandomKernelTester {\npublic:\n    FillRandomKernelTester() { }\n\n    FillRandomKernelTester(const FillRandomKernelTester&) = delete;\n    FillRandomKernelTester(FillRandomKernelTester&&) = delete;\n    FillRandomKernelTester& operator=(const FillRandomKernelTester&) = delete;\n    FillRandomKernelTester& operator=(FillRandomKernelTester&&) = delete;\n\n    [[nodiscard]]\n    FillRandomKernelTester& num_elements(std::uint32_t num_elements) {\n        num_elements_ = num_elements;\n        return *this;\n    }\n\n    std::uint32_t num_elements() const {\n        return num_elements_;\n    }\n\n    [[nodiscard]]\n    FillRandomKernelTester& threadgroup_size(std::size_t threadgroup_size) {\n        threadgroup_size_ = threadgroup_size;\n        return *this;\n    }\n\n    std::size_t threadgroup_size() const {\n        return threadgroup_size_;\n    }\n\n    [[nodiscard]]\n    FillRandomKernelTester& max_threadgroups(std::size_t max_threadgroups) {\n        max_threadgroups_ = max_threadgroups;\n        return *this;\n    }\n\n    std::size_t max_threadgroups() const {\n        return max_threadgroups_;\n    }\n\n    void Validate() const {\n        ASSERT_NE(num_elements(), 0);\n        ASSERT_NE(threadgroup_size(), 0);\n        ASSERT_NE(max_threadgroups(), 0);\n    }\n\n    void TestU32() const {\n        Validate();\n\n        metal::Buffer output_buffer{device_, num_elements() * sizeof(std::uint32_t)};\n\n        metal::CommandBuffer command_buffer{command_queue_};\n        command_buffer.encode_launch_u32_fill_random(\n            u32_fill_random_fn_,\n            threadgroup_size(),\n            max_threadgroups(),\n            output_buffer,\n            /*output_offset=*/0,\n            
num_elements(), kSeed, kOffset);\n\n        command_buffer.commit();\n        command_buffer.wait_completion();\n\n        const std::uint32_t* output_ptr = static_cast<const std::uint32_t*>(output_buffer.ptr());\n        for (std::size_t i = 0; i < num_elements(); i++) {\n            const std::uint32_t ref_value = gptoss::rng::squares32(kOffset + i, kSeed);\n            ASSERT_EQ(output_ptr[i], ref_value)\n                << \"at position \" << i << \" / \" << num_elements();\n        }\n    }\n\nprivate:\n    static constexpr uint64_t kSeed{UINT64_C(1019827666124465388)};\n    static constexpr uint64_t kOffset{UINT64_C(12345678901234567890)};\n\n    metal::Device device_{};\n    metal::CommandQueue command_queue_{device_};\n    metal::Library library_{device_};\n    metal::Function f32_fill_random_fn_{library_, \"gptoss_f32_fill_random\"};\n    metal::Function bf16_fill_random_fn_{library_, \"gptoss_bf16_fill_random\"};\n    metal::Function u32_fill_random_fn_{library_, \"gptoss_u32_fill_random\"};\n    std::uint32_t num_elements_{1};\n    std::size_t threadgroup_size_{32};\n    std::size_t max_threadgroups_{1};\n};\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/test/matmul-kernel-tester.hpp",
    "content": "#pragma once\n\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n\n#include <internal/datatype.hpp>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\nnamespace gptoss {\n\n// Succeeds when |a - b| <= max(abs_tol, rel_tol * max(|a|, |b|)); non-finite values always fail.\ntemplate <typename T>\n::testing::AssertionResult\nIsNearAbsRel(const char* a_expr, const char* b_expr, const char* abs_expr,\n             const char* rel_expr, T a, T b, T abs_tol, T rel_tol = 1.0) {\n\n    using std::abs;\n    if (!std::isfinite(a) || !std::isfinite(b)) {\n        return ::testing::AssertionFailure()\n               << \"Non-finite value(s): \" << a_expr << \"=\" << a << \", \" << b_expr\n               << \"=\" << b;\n    }\n    const T diff = abs(a - b);\n    const T rel = rel_tol * std::max(abs(a), abs(b));\n    const T thr = std::max(abs_tol, rel);\n\n    if (diff <= thr)\n        return ::testing::AssertionSuccess();\n\n    return ::testing::AssertionFailure()\n           << a_expr << \" vs \" << b_expr << \" differ by \" << diff\n           << \" > max(abs_tol=\" << abs_tol << \", rel_tol*max(|a|,|b|)=\" << rel\n           << \") with \" << abs_expr << \"=\" << abs_tol << \", \" << rel_expr << \"=\"\n           << rel_tol << \". \\n\"\n           << a_expr << \"=\" << a << \". 
\\n\"\n           << b_expr << \"=\" << b;\n}\n\n#define ASSERT_NEAR_ABS_REL(a, b, abs_tol, rel_tol) \\\n    ASSERT_PRED_FORMAT4(IsNearAbsRel<double>, a, b, abs_tol, rel_tol)\n\nclass MatMulKernelTester {\npublic:\n    MatMulKernelTester() { }\n\n    MatMulKernelTester(const MatMulKernelTester&) = delete;\n    MatMulKernelTester(MatMulKernelTester&&) = delete;\n    MatMulKernelTester& operator=(const MatMulKernelTester&) = delete;\n    MatMulKernelTester& operator=(MatMulKernelTester&&) = delete;\n\n    [[nodiscard]]\n    MatMulKernelTester& num_rows(std::uint32_t num_rows) {\n        num_rows_ = num_rows;\n        return *this;\n    }\n\n    std::uint32_t num_rows() const {\n        return num_rows_;\n    }\n\n    [[nodiscard]]\n    MatMulKernelTester& num_cols(std::uint32_t num_cols) {\n        num_cols_ = num_cols;\n        return *this;\n    }\n\n    std::uint32_t num_cols() const {\n        return num_cols_;\n    }\n\n    [[nodiscard]]\n    MatMulKernelTester& num_tokens(std::uint32_t num_tokens) {\n        num_tokens_ = num_tokens;\n        return *this;\n    }\n\n    std::uint32_t num_tokens() const {\n        return num_tokens_;\n    }\n\n    [[nodiscard]]\n    MatMulKernelTester& threadgroup_size(std::size_t threadgroup_size) {\n        threadgroup_size_ = threadgroup_size;\n        return *this;\n    }\n\n    std::size_t threadgroup_size() const {\n        return threadgroup_size_;\n    }\n\n    void Validate(std::uint32_t vec_size) const {\n        ASSERT_NE(num_rows(), 0);\n        ASSERT_NE(num_cols(), 0);\n        ASSERT_EQ(num_cols() % vec_size, 0);\n        ASSERT_NE(num_tokens(), 0);\n        ASSERT_NE(threadgroup_size(), 0);\n    }\n\n    enum class MatMulKernelType {\n        DECODE_OPTIMIZED,\n        PREFILL_QKV_OPTIMIZED,\n        PREFILL_ATTN_OUTPUT_OPTIMIZED,\n        PREFILL_MLP_GATE_OPTIMIZED,\n    };\n\n    void TestF32_BF16W(MatMulKernelType kernel_type = MatMulKernelType::DECODE_OPTIMIZED) const {\n        Validate(/*vec_size=*/4);\n\n  
      metal::CommandBuffer command_buffer_initialize{command_queue_};\n        metal::Buffer input_buffer{device_, num_tokens() * num_cols() * sizeof(float)};\n        metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)};\n        metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)};\n        metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)};\n        metal::Buffer output_buffer_copy{device_, num_tokens() * num_rows() * sizeof(float)};\n        // KV cache buffer for PREFILL_QKV_OPTIMIZED: assume head_dim=64, num_kv_heads=8\n        const std::uint32_t kHeadDim = 64;\n        const std::uint32_t kNumKvHeads = 8;\n        metal::Buffer kv_cache_buffer{device_, static_cast<std::size_t>(kNumKvHeads) * num_tokens() * 2 * kHeadDim * sizeof(float)};\n        metal::Buffer control_buffer{device_, sizeof(gptoss_control)};\n        std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));\n\n        command_buffer_initialize.encode_launch_f32_fill_random(\n            f32_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/input_buffer,\n            /*output_offset=*/0,\n            num_tokens() * num_cols(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        command_buffer_initialize.encode_launch_bf16_fill_random(\n            bf16_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/weight_buffer,\n            /*output_offset=*/0,\n            num_rows() * num_cols(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        command_buffer_initialize.encode_launch_bf16_fill_random(\n            bf16_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/bias_buffer,\n            
/*output_offset=*/0,\n            num_rows(), kSeed + 2, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        // Fill output buffer with random values to test matmul with add.\n        command_buffer_initialize.encode_launch_f32_fill_random(\n            f32_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/output_buffer,\n            /*output_offset=*/0, num_tokens() * num_rows(), kSeed + 3,\n            /*offset=*/0,\n            /*min=*/-1.0f, /*max=*/1.0);\n        command_buffer_initialize.commit();\n        command_buffer_initialize.wait_completion();\n        if (kernel_type ==\n            MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) {\n            // The attn-output kernel accumulates into the output buffer, so keep a copy of its\n            // initial (random) contents for the reference computation.\n            const uint64_t bytes =\n                uint64_t(num_tokens()) * uint64_t(num_rows()) * sizeof(float);\n\n            void* src = output_buffer.ptr();\n            void* dst = output_buffer_copy.ptr();\n            assert(src && dst && \"Buffers must be CPU-mappable for memcpy\");\n\n            std::memcpy(reinterpret_cast<std::byte*>(dst),\n                        reinterpret_cast<const std::byte*>(src), bytes);\n        }\n\n        metal::CommandBuffer command_buffer_compute{command_queue_};\n        switch (kernel_type) {\n        case MatMulKernelType::DECODE_OPTIMIZED:\n            Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(\n                      command_buffer_compute.handle(), f32_bf16w_matmul_fn_.handle(),\n                      /*threadgroup_size=*/threadgroup_size(), input_buffer.handle(),\n                      /*input_offset=*/0, weight_buffer.handle(),\n                      /*weight_offset=*/0, bias_buffer.handle(),\n                      /*bias_offset=*/0, output_buffer.handle(),\n                      /*output_offset=*/0, control_buffer.handle(),\n                      
/*control_offset=*/0, num_tokens(), num_cols(), num_rows()),\n                  \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul\");\n            break;\n        case MatMulKernelType::PREFILL_QKV_OPTIMIZED:\n            Check(\n                gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(\n                    command_buffer_compute.handle(),\n                    f32_bf16w_dense_matmul_qkv_fn_.handle(), input_buffer.handle(),\n                    /*input_offset=*/0, weight_buffer.handle(),\n                    /*weight_offset=*/0, bias_buffer.handle(),\n                    /*bias_offset=*/0, output_buffer.handle(),\n                    /*output_offset=*/0, kv_cache_buffer.handle(),\n                    /*kv_offset=*/0, control_buffer.handle(),\n                    /*control_offset=*/0, num_tokens(), num_cols(), num_rows(),\n                    /*max_tokens=*/num_tokens(), /*token_offset=*/0),\n                \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv\");\n            break;\n        case MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED:\n            Check(\n                gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(\n                    command_buffer_compute.handle(),\n                    f32_bf16w_dense_matmul_attn_output_fn_.handle(),\n                    input_buffer.handle(),\n                    /*input_offset=*/0, weight_buffer.handle(),\n                    /*weight_offset=*/0, bias_buffer.handle(),\n                    /*bias_offset=*/0, output_buffer.handle(),\n                    /*output_offset=*/0, control_buffer.handle(),\n                    /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),\n                \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output\");\n            break;\n        case MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED:\n            Check(\n                
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(\n                    command_buffer_compute.handle(),\n                    f32_bf16w_dense_matmul_mlp_gate_fn_.handle(),\n                    input_buffer.handle(),\n                    /*input_offset=*/0, weight_buffer.handle(),\n                    /*weight_offset=*/0, bias_buffer.handle(),\n                    /*bias_offset=*/0, output_buffer.handle(),\n                    /*output_offset=*/0, control_buffer.handle(),\n                    /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),\n                \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate\");\n            break;\n        }\n        command_buffer_compute.commit();\n        command_buffer_compute.wait_completion();\n        const float* input_ptr = static_cast<const float*>(input_buffer.ptr());\n        const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());\n        const gptoss_bfloat16* bias_ptr = static_cast<const gptoss_bfloat16*>(bias_buffer.ptr());\n        const float* output_ptr = static_cast<const float*>(output_buffer.ptr());\n        const float* kv_ptr = static_cast<const float*>(kv_cache_buffer.ptr());\n        const float* output_ptr_copy = static_cast<const float*>(output_buffer_copy.ptr());\n        for (size_t t = 0; t < num_tokens(); t++) {\n            for (size_t r = 0; r < num_rows(); r++) {\n                double ref_sum = upcast<double>(bias_ptr[r]);\n                for (size_t c = 0; c < num_cols(); c++) {\n                    const double ref_weight = upcast<double>(weight_ptr[r * num_cols() + c]);\n                    const double input_value = upcast<double>(input_ptr[t * num_cols() + c]);\n                    ref_sum = std::fma(input_value, ref_weight, ref_sum);\n                }\n\n                if (kernel_type ==\n                    MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) {\n                    ref_sum += 
upcast<double>(output_ptr_copy[t * num_rows() + r]);\n                }\n                if (kernel_type == MatMulKernelType::PREFILL_QKV_OPTIMIZED) {\n                    // In this optimized path, V rows are written to the kv cache at index 1.\n                    // Assume num_q_heads=64, num_kv_heads=8, head_dim=64 and QKV packed as [Q][K][V].\n                    const std::size_t v_rows_start = (64 + 8) * 64; // rows offset where V begins\n                    if (r >= v_rows_start) {\n                        const std::size_t v_row_index = r - v_rows_start;\n                        const std::size_t kv_head = v_row_index / kHeadDim;\n                        const std::size_t d = v_row_index % kHeadDim;\n                        const std::size_t kv_base = ((kv_head * num_tokens() + t) * 2 + 1) * kHeadDim;\n                        ASSERT_NEAR_ABS_REL(upcast<double>(kv_ptr[kv_base + d]), ref_sum, 2.0e-4, 1.0e-4)\n                            << \"token \" << t << \", v-row \" << r;\n                        continue;\n                    }\n                }\n                ASSERT_NEAR_ABS_REL(upcast<double>(output_ptr[t * num_rows() + r]),\n                                    ref_sum, 2.0e-4, 1.0e-4)\n                    << \"token \" << t;\n            }\n        }\n    }\n\nprivate:\n    static constexpr std::uint64_t kSeed{UINT64_C(1019827666124465388)};\n    static constexpr std::size_t kFillRandomMaxThreadgroups = 10;\n    static constexpr float fp4e2m1_to_fp32[16] = {\n        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,\n        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,\n    };\n\n    metal::Device device_{};\n    metal::CommandQueue command_queue_{device_};\n    metal::Library library_{device_};\n    metal::Function f32_fill_random_fn_{library_, \"gptoss_f32_fill_random\"};\n    metal::Function bf16_fill_random_fn_{library_, \"gptoss_bf16_fill_random\"};\n    metal::Function f32_bf16w_matmul_fn_{library_, 
\"gptoss_f32_bf16w_matmul\"};\n    metal::Function f32_bf16w_dense_matmul_qkv_fn_{library_, \"gptoss_f32_bf16w_dense_matmul_qkv\"};\n    metal::Function f32_bf16w_dense_matmul_attn_output_fn_{library_, \"gptoss_f32_bf16w_dense_matmul_attn_output\"};\n    metal::Function f32_bf16w_dense_matmul_mlp_gate_fn_{library_, \"gptoss_f32_bf16w_dense_matmul_mlp_gate\"};\n    std::uint32_t num_tokens_{1};\n    std::uint32_t num_rows_{1};\n    std::uint32_t num_cols_{32};\n    std::size_t threadgroup_size_{32};\n};\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/test/mf4-f32-convert.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n#include <ios>\n\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\nusing gptoss::Check;\nusing namespace gptoss::metal;\n\nconstexpr size_t kThreadgroupSize = 32;\n\n\nstatic constexpr float fp4e2m1_to_fp32[16] = {\n    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,\n    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,\n};\n\nTEST(MF4_F32_CONVERT, single_threadgroup_single_iteration) {\n    constexpr size_t num_blocks = kThreadgroupSize;\n    constexpr size_t num_elements = num_blocks * 32;\n    constexpr size_t num_bytes = num_elements / 2;\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function mf4_f32_convert_fn{library, \"gptoss_mf4_f32_convert\"};\n    Buffer block_buffer{device, num_bytes};\n    Buffer scale_buffer{device, num_blocks * sizeof(uint8_t)};\n    Buffer output_buffer{device, num_elements * sizeof(float)};\n\n    uint8_t* block_ptr = static_cast<uint8_t*>(block_buffer.ptr());\n    std::memset(block_ptr, 0, num_bytes);\n    for (size_t b = 0; b < num_blocks; b++) {\n        for (size_t i = 0; i < 32; i++) {\n            const uint8_t nibble = (i + b) & 0x0F;\n            const uint8_t byte = nibble << ((i % 2) * 4);\n            block_ptr[b * 16 + i / 2] |= byte;\n        }\n    }\n\n    uint8_t* scale_ptr = static_cast<uint8_t*>(scale_buffer.ptr());\n    for (size_t b = 0; b < num_blocks; b++) {\n        scale_ptr[b] = 127 - b;\n    }\n\n    Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(\n            command_buffer.handle(),\n            mf4_f32_convert_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/1,\n            block_buffer.handle(),\n            scale_buffer.handle(),\n            output_buffer.handle(),\n            num_elements),\n        
\"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(output_buffer.ptr());\n    for (size_t b = 0; b < num_blocks; b++) {\n        for (size_t i = 0; i < 32; i++) {\n            const uint8_t byte = block_ptr[b * 16 + i / 2];\n            const uint8_t nibble = (byte >> ((i % 2) * 4)) & 0x0F;\n            const float ref_scale = std::ldexp(1.0f, static_cast<int>(scale_ptr[b]) - 127);\n            const float ref_value = fp4e2m1_to_fp32[nibble] * ref_scale;\n            ASSERT_EQ(output_ptr[b * 32 + i], ref_value)\n                << \"at position \" << i << \" / 32\"\n                << \", block \" << b << \" / \" << num_blocks\n                << \", FP4e2m1 value \" << std::hex << uint32_t(nibble);\n        }\n    }\n}\n\nTEST(MF4_F32_CONVERT, multiple_threadgroups_multiple_iterations) {\n    constexpr size_t num_threadgroups = 2;\n    constexpr size_t num_blocks = num_threadgroups * (kThreadgroupSize + 1);\n    constexpr size_t num_elements = num_blocks * 32;\n    constexpr size_t num_bytes = num_elements / 2;\n\n    Device device;\n    CommandQueue command_queue{device};\n    CommandBuffer command_buffer{command_queue};\n    Library library{device};\n    Function mf4_f32_convert_fn{library, \"gptoss_mf4_f32_convert\"};\n    Buffer block_buffer{device, num_bytes};\n    Buffer scale_buffer{device, num_blocks * sizeof(uint8_t)};\n    Buffer output_buffer{device, num_elements * sizeof(float)};\n\n    uint8_t* block_ptr = static_cast<uint8_t*>(block_buffer.ptr());\n    std::memset(block_ptr, 0, num_bytes);\n    for (size_t b = 0; b < num_blocks; b++) {\n        for (size_t i = 0; i < 32; i++) {\n            const uint8_t nibble = (i + b) & 0x0F;\n            const uint8_t byte = nibble << ((i % 2) * 4);\n            block_ptr[b * 16 + i / 2] |= byte;\n        }\n    }\n\n    uint8_t* scale_ptr = 
static_cast<uint8_t*>(scale_buffer.ptr());\n    for (size_t b = 0; b < num_blocks; b++) {\n        scale_ptr[b] = 200 - b;\n    }\n\n    Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(\n            command_buffer.handle(),\n            mf4_f32_convert_fn.handle(),\n            /*threadgroup_size=*/kThreadgroupSize,\n            /*max_threadgroups=*/num_threadgroups,\n            block_buffer.handle(),\n            scale_buffer.handle(),\n            output_buffer.handle(),\n            num_elements),\n        \"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert\");\n\n    command_buffer.commit();\n    command_buffer.wait_completion();\n\n    const float* output_ptr = static_cast<const float*>(output_buffer.ptr());\n    for (size_t b = 0; b < num_blocks; b++) {\n        for (size_t i = 0; i < 32; i++) {\n            const uint8_t byte = block_ptr[b * 16 + i / 2];\n            const uint8_t nibble = (byte >> ((i % 2) * 4)) & 0x0F;\n            const float ref_scale = std::ldexp(1.0f, static_cast<int>(scale_ptr[b]) - 127);\n            const float ref_value = fp4e2m1_to_fp32[nibble] * ref_scale;\n            ASSERT_EQ(output_ptr[b * 32 + i], ref_value)\n                << \"at position \" << i << \" / 32\"\n                << \", block \" << b << \" / \" << num_blocks\n                << \", FP4e2m1 value \" << std::hex << uint32_t(nibble);\n        }\n    }\n}\n"
  },
  {
    "path": "gpt_oss/metal/test/rmsnorm-kernel-tester.hpp",
    "content": "#pragma once\n\n#include <gtest/gtest.h>\n\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n\n#include <internal/datatype.hpp>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n\nnamespace gptoss {\n\nclass RMSNormKernelTester {\npublic:\n    RMSNormKernelTester() { }\n\n    RMSNormKernelTester(const RMSNormKernelTester&) = delete;\n    RMSNormKernelTester(RMSNormKernelTester&&) = delete;\n    RMSNormKernelTester& operator=(const RMSNormKernelTester&) = delete;\n    RMSNormKernelTester& operator=(RMSNormKernelTester&&) = delete;\n\n    [[nodiscard]]\n    RMSNormKernelTester& num_channels(std::uint32_t num_channels) {\n        num_channels_ = num_channels;\n        return *this;\n    }\n\n    std::uint32_t num_channels() const {\n        return num_channels_;\n    }\n\n    [[nodiscard]]\n    RMSNormKernelTester& num_tokens(std::uint32_t num_tokens) {\n        num_tokens_ = num_tokens;\n        return *this;\n    }\n\n    std::uint32_t num_tokens() const {\n        return num_tokens_;\n    }\n\n    [[nodiscard]]\n    RMSNormKernelTester& epsilon(float epsilon) {\n        epsilon_ = epsilon;\n        return *this;\n    }\n\n    float epsilon() const {\n        return epsilon_;\n    }\n\n    void Validate() const {\n        ASSERT_NE(num_channels(), 0);\n        ASSERT_NE(num_tokens(), 0);\n        ASSERT_GE(epsilon(), 0.0f);\n    }\n\n    void TestF32_BF16W() const {\n        Validate();\n\n        metal::Buffer input_buffer{device_, num_tokens() * num_channels() * sizeof(float)};\n        metal::Buffer weight_buffer{device_, num_channels() * sizeof(gptoss_bfloat16)};\n        metal::Buffer output_buffer{device_, num_tokens() * num_channels() * sizeof(float)};\n        metal::Buffer control_buffer{device_, sizeof(gptoss_control)};\n        std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));\n\n        metal::CommandBuffer command_buffer{command_queue_};\n\n        command_buffer.encode_launch_f32_fill_random(\n  
          f32_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/input_buffer, /*output_offset=*/0,\n            num_tokens() * num_channels(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        command_buffer.encode_launch_bf16_fill_random(\n            bf16_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/weight_buffer, /*output_offset=*/0,\n            num_channels(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(\n                command_buffer.handle(),\n                f32_bf16w_rmsnorm_fn_.handle(),\n                input_buffer.handle(),\n                /*input_offset=*/0,\n                weight_buffer.handle(),\n                /*weight_offset=*/0,\n                output_buffer.handle(),\n                /*output_offset=*/0,\n                control_buffer.handle(),\n                /*control_offset=*/0,\n                num_tokens(),\n                num_channels(),\n                epsilon()),\n            \"gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm\");\n\n        command_buffer.commit();\n        command_buffer.wait_completion();\n\n        const float* input_ptr = static_cast<const float*>(input_buffer.ptr());\n        const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());\n        const float* output_ptr = static_cast<const float*>(output_buffer.ptr());\n        for (std::uint32_t t = 0; t < num_tokens(); t++) {\n            double sumsq = 0.0;\n            for (std::uint32_t c = 0; c < num_channels(); c++) {\n                const double val = static_cast<double>(input_ptr[t * num_channels() + c]);\n                sumsq = std::fma(val, val, sumsq);\n            }\n            const double avgsq = sumsq / 
static_cast<double>(num_channels());\n            const double scale = 1.0 / std::sqrt(avgsq + epsilon());\n            for (std::uint32_t c = 0; c < num_channels(); c++) {\n                const double input_val = upcast<double>(input_ptr[t * num_channels() + c]);\n                const double weight_val = upcast<double>(weight_ptr[c]);\n                const double ref_output = scale * input_val * weight_val;\n                const double output = upcast<double>(output_ptr[t * num_channels() + c]);\n                ASSERT_NEAR(output, ref_output, 1.0e-5 * std::abs(ref_output))\n                    << \"at channel \" << c << \" / \" << num_channels() << \", token \" << t << \" / \" << num_tokens()\n                    << \", input \" << input_val << \", weight \" << weight_val << \", scale \" << scale;\n            }\n        }\n    }\n\nprivate:\n    static constexpr std::uint64_t kSeed{UINT64_C(1019827666124465388)};\n    static constexpr std::size_t kFillRandomMaxThreadgroups = 10;\n\n    metal::Device device_{};\n    metal::CommandQueue command_queue_{device_};\n    metal::Library library_{device_};\n    metal::Function f32_fill_random_fn_{library_, \"gptoss_f32_fill_random\"};\n    metal::Function bf16_fill_random_fn_{library_, \"gptoss_bf16_fill_random\"};\n    metal::Function f32_bf16w_rmsnorm_fn_{library_, \"gptoss_f32_bf16w_rmsnorm\"};\n    std::uint32_t num_tokens_{1};\n    std::uint32_t num_channels_{1};\n    float epsilon_{1.0e-5f};\n};\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/test/rope-kernel-tester.hpp",
    "content": "#pragma once\n\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n\n#include <internal/datatype.hpp>\n#include <internal/metal.hpp>\n#include <internal/metal-kernels.h>\n\n\nnamespace gptoss {\n\nclass RoPEKernelTester {\npublic:\n    RoPEKernelTester() { }\n\n    RoPEKernelTester(const RoPEKernelTester&) = delete;\n    RoPEKernelTester(RoPEKernelTester&&) = delete;\n    RoPEKernelTester& operator=(const RoPEKernelTester&) = delete;\n    RoPEKernelTester& operator=(RoPEKernelTester&&) = delete;\n\n    [[nodiscard]]\n    RoPEKernelTester& threadgroup_size(std::size_t threadgroup_size) {\n        threadgroup_size_ = threadgroup_size;\n        return *this;\n    }\n\n    std::size_t threadgroup_size() const {\n        return threadgroup_size_;\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& head_dim(std::uint32_t head_dim) {\n        head_dim_ = head_dim;\n        return *this;\n    }\n\n    std::uint32_t head_dim() const {\n        return head_dim_;\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& num_q_heads(std::uint32_t num_q_heads) {\n        num_q_heads_ = num_q_heads;\n        return *this;\n    }\n\n    std::uint32_t num_q_heads() const {\n        return num_q_heads_;\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& num_kv_heads(std::uint32_t num_kv_heads) {\n        num_kv_heads_ = num_kv_heads;\n        return *this;\n    }\n\n    std::uint32_t num_kv_heads() const {\n        return num_kv_heads_;\n    }\n\n    std::uint32_t num_qk_heads() const {\n        return num_q_heads() + num_kv_heads();\n    }\n\n    std::uint32_t num_qkv_heads() const {\n        return num_q_heads() + 2 * num_kv_heads();\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& num_tokens(std::uint32_t num_tokens) {\n        num_tokens_ = num_tokens;\n        return *this;\n    }\n\n    std::uint32_t num_tokens() const {\n        return num_tokens_;\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& token_offset(std::uint32_t token_offset) {\n
        token_offset_ = token_offset;\n        return *this;\n    }\n\n    std::uint32_t token_offset() const {\n        return token_offset_;\n    }\n\n    [[nodiscard]]\n    RoPEKernelTester& frequency_base(float frequency_base) {\n        frequency_base_ = frequency_base;\n        return *this;\n    }\n\n    float frequency_base() const {\n        return frequency_base_;\n    }\n\n    void Validate() const {\n        ASSERT_NE(head_dim(), 0);\n        ASSERT_EQ(head_dim() % 2, 0);\n        ASSERT_NE(num_q_heads(), 0);\n        ASSERT_NE(num_tokens(), 0);\n    }\n\n    void TestF32() const {\n        Validate();\n\n        metal::Buffer activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};\n        metal::Buffer ref_activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};\n        // KV cache buffer layout: [num_kv_heads][max_tokens][2 (K,V)][head_dim]\n        const std::uint32_t max_tokens = num_tokens();\n        const std::uint32_t kv_heads_for_alloc = std::max<std::uint32_t>(1, num_kv_heads());\n        const std::size_t kv_bytes = static_cast<std::size_t>(kv_heads_for_alloc) * max_tokens * 2 * head_dim() * sizeof(float);\n        metal::Buffer kv_cache_buffer{device_, kv_bytes};\n        metal::Buffer control_buffer{device_, sizeof(gptoss_control)};\n        std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));\n\n        metal::CommandBuffer command_buffer{command_queue_};\n\n        command_buffer.encode_launch_f32_fill_random(\n            f32_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/activations_buffer,\n            /*output_offset=*/0,\n            (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim(),\n            kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        
command_buffer.encode_launch_f32_fill_random(\n            f32_fill_random_fn_,\n            /*threadgroup_size=*/0,\n            /*max_threadgroups=*/kFillRandomMaxThreadgroups,\n            /*output_buffer=*/ref_activations_buffer,\n            /*output_offset=*/0,\n            (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim(),\n            kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);\n\n        Check(gptoss_metal_command_buffer_encode_launch_f32_rope(\n                command_buffer.handle(),\n                f32_rope_fn_.handle(),\n                threadgroup_size(),\n                activations_buffer.handle(),\n                /*activations_offset=*/0,\n                kv_cache_buffer.handle(),\n                /*kv_offset=*/0,\n                control_buffer.handle(),\n                /*control_offset=*/0,\n                frequency_base(),\n                /*interpolation_scale=*/1.0f,\n                /*yarn_offset=*/0.0f,\n                /*yarn_scale=*/1.0f,\n                /*yarn_multiplier=*/1.0f,\n                /*num_tokens=*/num_tokens(),\n                /*num_q_heads=*/num_q_heads(),\n                /*num_kv_heads=*/num_kv_heads(),\n                head_dim(),\n                /*max_tokens=*/max_tokens,\n                /*token_offset=*/token_offset()),\n            \"gptoss_metal_command_buffer_encode_launch_f32_rope\");\n\n        command_buffer.commit();\n        command_buffer.wait_completion();\n\n        const float* ref_activations_ptr = static_cast<const float*>(ref_activations_buffer.ptr());\n        const float* activations_ptr = static_cast<const float*>(activations_buffer.ptr());\n        const float* kv_ptr = static_cast<const float*>(kv_cache_buffer.ptr());\n        for (std::uint32_t t = 0; t < num_tokens(); t++) {\n            // Validate rotated Q written in-place in activations\n            for (std::uint32_t h = 0; h < num_q_heads(); h++) {\n                for (std::uint32_t d = 0; d < head_dim(); d += 2) 
{\n                    const double inv_freq = 1.0 /\n                        std::pow(static_cast<double>(frequency_base()), static_cast<double>(d) / static_cast<double>(head_dim()));\n                    const double phi = static_cast<double>(t + token_offset()) * inv_freq;\n                    const double cos_phi = std::cos(phi);\n                    const double sin_phi = std::sin(phi);\n                    const double real = static_cast<double>(ref_activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d]);\n                    const double imag = static_cast<double>(ref_activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d + 1]);\n                    const double ref_real = real * cos_phi - imag * sin_phi;\n                    const double ref_imag = real * sin_phi + imag * cos_phi;\n                    ASSERT_NEAR(\n                            static_cast<double>(activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d]),\n                            ref_real,\n                            std::abs(ref_real) * 1.0e-4)\n                        << \"at token \" << t << \" / \" << num_tokens();\n                    ASSERT_NEAR(\n                            static_cast<double>(activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d + 1]),\n                            ref_imag,\n                            std::abs(ref_imag) * 1.0e-4)\n                        << \"at token \" << t << \" / \" << num_tokens();\n\n                }\n            }\n        }\n    }\n\nprivate:\n    static constexpr uint64_t kSeed{UINT64_C(1019827666124465388)};\n    static constexpr std::size_t kFillRandomMaxThreadgroups = 10;\n\n    metal::Device device_{};\n    metal::CommandQueue command_queue_{device_};\n    metal::Library library_{device_};\n    metal::Function f32_fill_random_fn_{library_, \"gptoss_f32_fill_random\"};\n    metal::Function f32_rope_fn_{library_, \"gptoss_f32_rope\"};\n    std::size_t threadgroup_size_{32};\n    std::uint32_t head_dim_{64};\n    
std::uint32_t num_q_heads_{1};\n    std::uint32_t num_kv_heads_{1};\n    std::uint32_t num_tokens_{1};\n    std::uint32_t token_offset_{0};\n    float frequency_base_{50000.0f};\n};\n\n}  // namespace gptoss\n"
  },
  {
    "path": "gpt_oss/metal/test/u32-random.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <cstddef>\n#include <cstdint>\n\n#include \"fill-random-kernel-tester.hpp\"\n\n\nusing gptoss::FillRandomKernelTester;\n\nconstexpr std::size_t kThreadgroupSize = 128;\n\nTEST(U32_FILL_RANDOM, single_threadgroup_single_iteration) {\n    FillRandomKernelTester()\n        .num_elements(kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(1)\n        .TestU32();\n}\n\nTEST(U32_FILL_RANDOM, single_threadgroup_multiple_iterations) {\n    constexpr std::size_t num_iterations = 3;\n\n    FillRandomKernelTester()\n        .num_elements(num_iterations * kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(1)\n        .TestU32();\n}\n\nTEST(U32_FILL_RANDOM, multiple_threadgroups_multiple_iterations) {\n    constexpr std::size_t num_iterations = 3;\n    constexpr std::size_t num_threadgroups = 2;\n\n    FillRandomKernelTester()\n        .num_elements(num_iterations * num_threadgroups * kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(num_threadgroups)\n        .TestU32();\n}\n\nTEST(U32_FILL_RANDOM, excessive_threadgroups) {\n    FillRandomKernelTester()\n        .num_elements(kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(2)\n        .TestU32();\n}\n\nTEST(U32_FILL_RANDOM, nonuniform_range) {\n    constexpr std::size_t num_iterations = 3;\n    constexpr std::size_t num_threadgroups = 2;\n\n    FillRandomKernelTester()\n        .num_elements((num_iterations * num_threadgroups + 1) * kThreadgroupSize)\n        .threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(num_threadgroups)\n        .TestU32();\n}\n\nTEST(U32_FILL_RANDOM, partial_range) {\n    constexpr std::size_t num_iterations = 3;\n    constexpr std::size_t num_threadgroups = 2;\n\n    FillRandomKernelTester()\n        .num_elements(num_iterations * num_threadgroups * kThreadgroupSize + 1)\n        
.threadgroup_size(kThreadgroupSize)\n        .max_threadgroups(num_threadgroups)\n        .TestU32();\n}\n"
  },
  {
    "path": "gpt_oss/responses_api/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/responses_api/api_server.py",
    "content": "import os\nimport datetime\nimport uuid\nfrom typing import Callable, Literal, Optional, Union\n\nfrom fastapi import FastAPI, Request\nfrom fastapi.exception_handlers import request_validation_exception_handler\nfrom fastapi.exceptions import RequestValidationError\nfrom fastapi.responses import StreamingResponse\nfrom openai_harmony import (\n    Author,\n    Conversation,\n    DeveloperContent,\n    HarmonyEncoding,\n    Message,\n    ReasoningEffort,\n    Role,\n    StreamableParser,\n    StreamState,\n    SystemContent,\n    ToolDescription,\n)\n\nfrom gpt_oss.tools.python_docker.docker_tool import PythonTool\nfrom gpt_oss.tools.simple_browser import SimpleBrowserTool\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend\n\nfrom .events import (\n    ResponseCodeInterpreterCallCodeDelta,\n    ResponseCodeInterpreterCallCodeDone,\n    ResponseCodeInterpreterCallCompleted,\n    ResponseCodeInterpreterCallInProgress,\n    ResponseCodeInterpreterCallInterpreting,\n    ResponseCompletedEvent,\n    ResponseContentPartAdded,\n    ResponseContentPartDone,\n    ResponseCreatedEvent,\n    ResponseEvent,\n    ResponseInProgressEvent,\n    ResponseOutputItemAdded,\n    ResponseOutputItemDone,\n    ResponseOutputTextAnnotationAdded,\n    ResponseOutputTextDelta,\n    ResponseOutputTextDone,\n    ResponseReasoningTextDelta,\n    ResponseReasoningTextDone,\n    ResponseWebSearchCallCompleted,\n    ResponseWebSearchCallInProgress,\n    ResponseWebSearchCallSearching,\n)\nfrom .types import (\n    CodeInterpreterCallItem,\n    CodeInterpreterOutputImage,\n    CodeInterpreterOutputLogs,\n    Error,\n    FunctionCallItem,\n    Item,\n    ReasoningItem,\n    ReasoningTextContentItem,\n    ResponseObject,\n    ResponsesRequest,\n    TextContentItem,\n    UrlCitation,\n    Usage,\n    WebSearchActionFind,\n    WebSearchActionOpenPage,\n    WebSearchActionSearch,\n    WebSearchCallItem,\n)\n\nDEFAULT_TEMPERATURE = 0.0\n\n\ndef 
get_reasoning_effort(\n    effort: Union[Literal[\"low\", \"medium\", \"high\"], ReasoningEffort]\n) -> ReasoningEffort:\n    if isinstance(effort, ReasoningEffort):\n        return effort\n    if effort == \"low\":\n        return ReasoningEffort.LOW\n    if effort == \"medium\":\n        return ReasoningEffort.MEDIUM\n    if effort == \"high\":\n        return ReasoningEffort.HIGH\n    raise ValueError(f\"Invalid reasoning effort: {effort}\")\n\n\ndef is_not_builtin_tool(\n    recipient: str, treat_functions_python_as_builtin: bool = False\n) -> bool:\n    if treat_functions_python_as_builtin and recipient == \"functions.python\":\n        return False\n    return (\n        not recipient.startswith(\"browser.\")\n        and recipient != \"python\"\n        and recipient != \"assistant\"\n    )\n\n\ndef create_api_server(\n    infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding\n) -> FastAPI:\n    app = FastAPI()\n\n    @app.exception_handler(RequestValidationError)\n    async def log_validation_error(request: Request, exc: RequestValidationError):\n        try:\n            body_bytes = await request.body()\n            print(\n                \"Invalid request body received:\"\n                f\" {body_bytes.decode('utf-8', errors='replace')}\"\n            )\n        except Exception as body_exc:\n            print(f\"Failed to read invalid request body: {body_exc}\")\n        return await request_validation_exception_handler(request, exc)\n    responses_store: dict[str, tuple[ResponsesRequest, ResponseObject]] = {}\n\n    def generate_response(\n        input_tokens: list[int],\n        output_tokens: list[int],\n        request_body: ResponsesRequest,\n        debug_mode: bool = False,\n        function_call_ids: Optional[list[tuple[str, str]]] = None,\n        response_id: Optional[str] = None,\n        previous_response_id: Optional[str] = None,\n        browser_tool: Optional[SimpleBrowserTool] = None,\n        
browser_call_ids: Optional[list[str]] = None,\n        python_tool: Optional[PythonTool] = None,\n        python_call_ids: Optional[list[str]] = None,\n        python_call_outputs: Optional[\n            dict[str, list[CodeInterpreterOutputLogs | CodeInterpreterOutputImage]]\n        ] = None,\n        reasoning_ids: Optional[list[str]] = None,\n        message_ids: Optional[list[str]] = None,\n        treat_functions_python_as_builtin: bool = False,\n    ) -> ResponseObject:\n        output = []\n        error = None\n        if len(output_tokens) > 0:\n            if debug_mode:\n                try:\n                    entries = encoding.parse_messages_from_completion_tokens(\n                        output_tokens, Role.ASSISTANT\n                    )\n                except Exception as e:\n                    print(f\"Error parsing tokens: {e}\")\n                    error = Error(\n                        code=\"invalid_function_call\",\n                        message=f\"{e}\",\n                    )\n                    entries = []\n            else:\n                entries = encoding.parse_messages_from_completion_tokens(\n                    output_tokens, Role.ASSISTANT\n                )\n\n            fc_index = 0\n            browser_tool_index = 0\n            python_tool_index = 0\n            reasoning_ids_iter = iter(reasoning_ids or [])\n            message_ids_iter = iter(message_ids or [])\n\n            for entry in entries:\n                entry_dict = entry.to_dict()\n                recipient = entry_dict.get(\"recipient\", \"\")\n                if len(recipient) > 0 and is_not_builtin_tool(\n                    recipient, treat_functions_python_as_builtin\n                ):\n                    call = entry_dict[\"content\"][0]\n                    arguments = call[\"text\"]\n                    name = recipient\n\n                    if name.startswith(\"functions.\"):\n                        name = name[len(\"functions.\") :]\n   
                 if function_call_ids and fc_index < len(function_call_ids):\n                        fc_id, call_id = function_call_ids[fc_index]\n                    else:\n                        fc_id, call_id = (\n                            f\"fc_{uuid.uuid4().hex}\",\n                            f\"call_{uuid.uuid4().hex}\",\n                        )\n                    fc_index += 1\n                    output.append(\n                        FunctionCallItem(\n                            type=\"function_call\",\n                            name=name,\n                            arguments=arguments,\n                            id=fc_id,\n                            call_id=call_id,\n                        )\n                    )\n                elif (\n                    len(recipient) > 0\n                    and recipient.startswith(\"browser.\")\n                    and browser_tool is not None\n                ):\n                    # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked\n                    name = recipient\n                    call = entry_dict[\"content\"][0]\n                    arguments = call[\"text\"]\n                    function_name = name[len(\"browser.\") :]\n\n                    # Reconstruct a Message for argument parsing\n                    tool_msg = (\n                        Message.from_role_and_content(Role.ASSISTANT, arguments)\n                        .with_recipient(name)\n                        .with_channel(\"analysis\")\n                    )\n\n                    action = None\n                    try:\n                        parsed_args = browser_tool.process_arguments(tool_msg)\n                        if function_name == \"search\":\n                            action = WebSearchActionSearch(\n                                type=\"search\",\n                                query=parsed_args[\"query\"],\n                            )\n                        elif 
function_name == \"open\":\n                            action = WebSearchActionOpenPage(\n                                type=\"open_page\",\n                                url=parsed_args[\"url\"],\n                            )\n                        elif function_name == \"find\":\n                            action = WebSearchActionFind(\n                                type=\"find\",\n                                pattern=parsed_args[\"pattern\"],\n                                url=parsed_args[\"url\"],\n                            )\n                    except Exception as e:\n                        print(f\"Error processing browser tool arguments: {e}\")\n                        action = None\n\n                    if action is not None:\n                        if browser_call_ids and browser_tool_index < len(\n                            browser_call_ids\n                        ):\n                            web_search_call_id = browser_call_ids[browser_tool_index]\n                        else:\n                            web_search_call_id = f\"ws_{uuid.uuid4().hex}\"\n                        browser_tool_index += 1\n                        output.append(\n                            WebSearchCallItem(\n                                type=\"web_search_call\",\n                                id=web_search_call_id,\n                                action=action,\n                            )\n                        )\n                elif (\n                    len(recipient) > 0\n                    and (\n                        recipient.startswith(\"python\")\n                        or (\n                            treat_functions_python_as_builtin\n                            and recipient == \"functions.python\"\n                        )\n                    )\n                    and python_tool is not None\n                ):\n                    if python_call_ids and python_tool_index < len(python_call_ids):\n                  
      code_call_id = python_call_ids[python_tool_index]\n                    else:\n                        code_call_id = f\"ci_{uuid.uuid4().hex}\"\n                    python_tool_index += 1\n                    code_snippet = None\n                    if entry_dict.get(\"content\"):\n                        code_snippet = entry_dict[\"content\"][0].get(\"text\")\n                    outputs = (\n                        (python_call_outputs or {}).get(code_call_id)\n                        if python_call_outputs\n                        else None\n                    )\n                    output.append(\n                        CodeInterpreterCallItem(\n                            type=\"code_interpreter_call\",\n                            id=code_call_id,\n                            status=\"completed\",\n                            code=code_snippet,\n                            outputs=outputs,\n                        )\n                    )\n                elif entry_dict[\"channel\"] == \"final\":\n                    content = []\n                    for content_entry in entry_dict[\"content\"]:\n                        if browser_tool:\n                            text_content, annotation_entries, _has_partial_citations = (\n                                browser_tool.normalize_citations(content_entry[\"text\"])\n                            )\n                            annotations = [UrlCitation(**a) for a in annotation_entries]\n                        else:\n                            text_content = content_entry[\"text\"]\n                            annotations = []\n\n                        content.append(\n                            TextContentItem(\n                                type=\"output_text\",\n                                text=text_content,\n                                annotations=annotations,\n                            )\n                        )\n\n                    message_id = next(message_ids_iter, None)\n     
               output.append(\n                        Item(\n                            id=message_id,\n                            type=\"message\",\n                            role=\"assistant\",\n                            content=content,\n                            status=\"completed\",\n                        )\n                    )\n                elif entry_dict[\"channel\"] == \"analysis\":\n                    if entry_dict.get(\"recipient\"):\n                        continue\n                    author_dict = entry_dict.get(\"author\") or {}\n                    if author_dict.get(\"role\") and author_dict.get(\"role\") != \"assistant\":\n                        continue\n                    summary = []\n                    content = [\n                        ReasoningTextContentItem(\n                            type=\"reasoning_text\",\n                            text=entry[\"text\"],\n                        )\n                        for entry in entry_dict[\"content\"]\n                    ]\n                    reasoning_id = next(reasoning_ids_iter, None)\n                    if reasoning_id is None:\n                        reasoning_id = f\"rs_{uuid.uuid4().hex}\"\n                    output.append(\n                        ReasoningItem(\n                            id=reasoning_id,\n                            type=\"reasoning\",\n                            summary=summary,\n                            content=content,\n                        )\n                    )\n        else:\n            output = []\n\n        usage = (\n            Usage(\n                input_tokens=len(input_tokens),\n                output_tokens=len(output_tokens),\n                total_tokens=len(input_tokens) + len(output_tokens),\n            )\n            if len(output_tokens) > 0\n            else None\n        )\n\n        try:\n            debug_str = encoding.decode_utf8(input_tokens + output_tokens)\n        except Exception:\n            
debug_str = input_tokens + output_tokens\n        try:\n            debug_input_str = encoding.decode_utf8(input_tokens)\n        except Exception:\n            debug_input_str = input_tokens\n        try:\n            debug_output_str = encoding.decode_utf8(output_tokens)\n        except Exception:\n            debug_output_str = output_tokens\n\n        metadata = (\n            {\n                \"__debug\": debug_str,\n                \"__debug_input\": debug_input_str,\n                \"__debug_output\": debug_output_str,\n            }\n            if debug_mode\n            else {}\n        )\n\n        return ResponseObject(\n            created_at=int(datetime.datetime.now().timestamp()),\n            status=\"completed\",\n            output=output,\n            text={\"format\": {\"type\": \"text\"}},\n            usage=usage,\n            max_output_tokens=request_body.max_output_tokens,\n            error=error,\n            metadata=metadata,\n            id=response_id,\n            previous_response_id=previous_response_id,\n        )\n\n    class StreamResponsesEvents:\n        BROWSER_RESERVED_FUNCTIONS = {\"browser.search\", \"browser.open\", \"browser.find\"}\n        initial_tokens: list[int]\n        tokens: list[int]\n        output_tokens: list[int]\n        output_text: str\n        request_body: ResponsesRequest\n        request: Request\n        sequence_number: int\n\n        def __init__(\n            self,\n            initial_tokens,\n            request_body: ResponsesRequest,\n            as_sse: bool = False,\n            request: Optional[Request] = None,\n            response_id: Optional[str] = None,\n            store_callback: Optional[\n                Callable[[str, ResponsesRequest, ResponseObject], None]\n            ] = None,\n            browser_tool: Optional[SimpleBrowserTool] = None,\n            python_tool: Optional[PythonTool] = None,\n            functions_python_as_builtin: bool = False,\n        ):\n           
 self.initial_tokens = initial_tokens\n            self.tokens = initial_tokens.copy()\n            self.output_tokens = []\n            self.output_text = \"\"\n            self.request_body = request_body\n            self.parser = StreamableParser(encoding, role=Role.ASSISTANT)\n            self.as_sse = as_sse\n            self.debug_mode = request_body.metadata.get(\n                \"__debug\", False\n            )  # we use this for demo purposes\n            # Set temperature for this stream, fallback to DEFAULT_TEMPERATURE if not set\n            self.temperature = (\n                request_body.temperature\n                if request_body.temperature is not None\n                else DEFAULT_TEMPERATURE\n            )\n            self.request = request\n            self.sequence_number = 0\n            self.function_call_ids: list[tuple[str, str]] = []\n            self.response_id = response_id\n            self.store_callback = store_callback\n            self.new_request = True\n            self.browser_tool = browser_tool\n            self.use_browser_tool = browser_tool is not None\n            self.browser_call_ids: list[str] = []\n            self.python_tool = python_tool\n            self.use_code_interpreter = python_tool is not None\n            self.python_call_ids: list[str] = []\n            self.python_call_outputs: dict[\n                str, list[CodeInterpreterOutputLogs | CodeInterpreterOutputImage]\n            ] = {}\n            self.reasoning_item_ids: list[str] = []\n            self.current_reasoning_item_id: Optional[str] = None\n            self.message_item_ids: list[str] = []\n            self.current_message_item_id: Optional[str] = None\n            self.functions_python_as_builtin = functions_python_as_builtin\n            self.user_defined_function_names = {\n                name\n                for tool in (request_body.tools or [])\n                for name in [getattr(tool, \"name\", None)]\n                if 
getattr(tool, \"type\", None) == \"function\" and name\n            }\n\n        def _resolve_browser_recipient(\n            self, recipient: Optional[str]\n        ) -> tuple[Optional[str], bool]:\n            if not self.use_browser_tool or not recipient:\n                return (None, False)\n\n            if recipient.startswith(\"browser.\"):\n                return (recipient, False)\n\n            if recipient.startswith(\"functions.\"):\n                potential = recipient[len(\"functions.\") :]\n                if (\n                    potential in self.BROWSER_RESERVED_FUNCTIONS\n                    and potential not in self.user_defined_function_names\n                ):\n                    return (potential, True)\n\n            return (None, False)\n\n        def _ensure_message_item_id(self) -> str:\n            if self.current_message_item_id is None:\n                message_id = f\"item_{uuid.uuid4().hex}\"\n                self.current_message_item_id = message_id\n                self.message_item_ids.append(message_id)\n            return self.current_message_item_id\n\n        def _ensure_reasoning_item_id(self) -> str:\n            if self.current_reasoning_item_id is None:\n                reasoning_id = f\"rs_{uuid.uuid4().hex}\"\n                self.current_reasoning_item_id = reasoning_id\n                self.reasoning_item_ids.append(reasoning_id)\n            return self.current_reasoning_item_id\n\n        def _send_event(self, event: ResponseEvent):\n            event.sequence_number = self.sequence_number\n            self.sequence_number += 1\n            if self.as_sse:\n                return f\"event: {event.type}\\ndata: {event.model_dump_json(indent=None)}\\n\\n\"\n            else:\n                return event\n\n        async def run(self):\n            browser_tool = self.browser_tool\n            self.new_request = True\n            initial_response = generate_response(\n                self.initial_tokens,\n         
       self.output_tokens,\n                self.request_body,\n                function_call_ids=self.function_call_ids,\n                response_id=self.response_id,\n                previous_response_id=self.request_body.previous_response_id,\n                browser_tool=self.browser_tool,\n                browser_call_ids=self.browser_call_ids,\n                python_tool=self.python_tool,\n                python_call_ids=self.python_call_ids,\n                python_call_outputs=getattr(self, \"python_call_outputs\", None),\n                reasoning_ids=self.reasoning_item_ids,\n                message_ids=self.message_item_ids,\n                treat_functions_python_as_builtin=self.functions_python_as_builtin,\n            )\n            initial_response.status = \"in_progress\"\n            yield self._send_event(\n                ResponseCreatedEvent(\n                    type=\"response.created\",\n                    response=initial_response,\n                )\n            )\n            yield self._send_event(\n                ResponseInProgressEvent(\n                    type=\"response.in_progress\",\n                    response=initial_response,\n                )\n            )\n\n            current_content_index = (\n                0  # for this implementation we will always have one content item only\n            )\n            current_output_index = -1\n            sent_output_item_added = False\n\n            # we use this if the model outputs a citation to buffer until completed\n            output_delta_buffer = \"\"\n            # we use this to track the current output text content for things like providing the right indices in citations\n            current_output_text_content = \"\"\n            current_annotations = []\n\n            while True:\n                # Check for client disconnect\n                if self.request is not None and await self.request.is_disconnected():\n                    print(\"Client disconnected, 
stopping token generation.\")\n                    break\n                next_tok = infer_next_token(\n                    self.tokens,\n                    temperature=self.temperature,\n                    new_request=self.new_request,\n                )\n                self.new_request = False\n                self.tokens.append(next_tok)\n                try:\n                    self.parser.process(next_tok)\n                except Exception:\n                    pass\n\n                if self.parser.state == StreamState.EXPECT_START:\n                    current_output_index += 1\n                    sent_output_item_added = False\n\n                    if len(self.parser.messages) > 0:\n                        previous_item = self.parser.messages[-1]\n                        if previous_item.recipient is not None:\n                            recipient = previous_item.recipient\n                            browser_recipient, _ = self._resolve_browser_recipient(\n                                recipient\n                            )\n                            if (\n                                browser_recipient is None\n                                and not (\n                                    recipient == \"python\"\n                                    or (\n                                        self.functions_python_as_builtin\n                                        and recipient == \"functions.python\"\n                                    )\n                                )\n                            ):\n                                fc_id = f\"fc_{uuid.uuid4().hex}\"\n                                call_id = f\"call_{uuid.uuid4().hex}\"\n                                self.function_call_ids.append((fc_id, call_id))\n                                yield self._send_event(\n                                    ResponseOutputItemDone(\n                                        type=\"response.output_item.done\",\n                         
               output_index=current_output_index,\n                                        item=FunctionCallItem(\n                                            type=\"function_call\",\n                                            name=(\n                                                previous_item.recipient[\n                                                    len(\"functions.\") :\n                                                ]\n                                                if previous_item.recipient.startswith(\n                                                    \"functions.\"\n                                                )\n                                                else previous_item.recipient\n                                            ),\n                                            arguments=previous_item.content[0].text,\n                                            id=fc_id,\n                                            call_id=call_id,\n                                        ),\n                                    )\n                                )\n                        if (\n                            previous_item.channel == \"analysis\"\n                            and previous_item.recipient is None\n                        ):\n                            reasoning_id = (\n                                self.current_reasoning_item_id\n                                if self.current_reasoning_item_id is not None\n                                else self._ensure_reasoning_item_id()\n                            )\n                            reasoning_text = previous_item.content[0].text\n                            yield self._send_event(\n                                ResponseReasoningTextDone(\n                                    type=\"response.reasoning_text.done\",\n                                    output_index=current_output_index,\n                                    content_index=current_content_index,\n                  
                  item_id=reasoning_id,\n                                    text=reasoning_text,\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseContentPartDone(\n                                    type=\"response.content_part.done\",\n                                    output_index=current_output_index,\n                                    content_index=current_content_index,\n                                    item_id=reasoning_id,\n                                    part=ReasoningTextContentItem(\n                                        type=\"reasoning_text\",\n                                        text=reasoning_text,\n                                    ),\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseOutputItemDone(\n                                    type=\"response.output_item.done\",\n                                    output_index=current_output_index,\n                                    item=ReasoningItem(\n                                        id=reasoning_id,\n                                        type=\"reasoning\",\n                                        summary=[],\n                                        content=[\n                                            ReasoningTextContentItem(\n                                                type=\"reasoning_text\",\n                                                text=reasoning_text,\n                                            )\n                                        ],\n                                    ),\n                                )\n                            )\n                            self.current_reasoning_item_id = None\n                        if previous_item.channel == \"final\":\n                            annotations = [\n                    
            UrlCitation(**a) for a in current_annotations\n                            ]\n                            if browser_tool:\n                                (\n                                    normalized_text,\n                                    _annotations,\n                                    _has_partial_citations,\n                                ) = browser_tool.normalize_citations(\n                                    previous_item.content[0].text\n                                )\n                            else:\n                                normalized_text = previous_item.content[0].text\n                                annotations = []\n                            text_content = TextContentItem(\n                                type=\"output_text\",\n                                text=normalized_text,\n                                annotations=annotations,\n                            )\n                            message_id = (\n                                self.current_message_item_id\n                                if self.current_message_item_id is not None\n                                else self._ensure_message_item_id()\n                            )\n                            yield self._send_event(\n                                ResponseOutputTextDone(\n                                    type=\"response.output_text.done\",\n                                    output_index=current_output_index,\n                                    content_index=current_content_index,\n                                    item_id=message_id,\n                                    text=normalized_text,\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseContentPartDone(\n                                    type=\"response.content_part.done\",\n                                    output_index=current_output_index,\n             
                       content_index=current_content_index,\n                                    item_id=message_id,\n                                    part=text_content,\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseOutputItemDone(\n                                    type=\"response.output_item.done\",\n                                    output_index=current_output_index,\n                                    item=Item(\n                                        id=message_id,\n                                        type=\"message\",\n                                        role=\"assistant\",\n                                        content=[text_content],\n                                    ),\n                                )\n                            )\n                            current_annotations = []\n                            current_output_text_content = \"\"\n                            self.current_message_item_id = None\n\n                if (\n                    self.parser.last_content_delta\n                    and self.parser.current_channel == \"final\"\n                    and self.parser.current_recipient is None\n                ):\n                    if not sent_output_item_added:\n                        sent_output_item_added = True\n                        message_id = self._ensure_message_item_id()\n                        yield self._send_event(\n                            ResponseOutputItemAdded(\n                                type=\"response.output_item.added\",\n                                output_index=current_output_index,\n                                item=Item(\n                                    id=message_id,\n                                    type=\"message\",\n                                    role=\"assistant\",\n                                    content=[],\n                            
    ),\n                            )\n                        )\n                        yield self._send_event(\n                            ResponseContentPartAdded(\n                                type=\"response.content_part.added\",\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                item_id=message_id,\n                                part=TextContentItem(type=\"output_text\", text=\"\"),\n                            )\n                        )\n\n                    output_delta_buffer += self.parser.last_content_delta\n                    should_send_output_text_delta = True\n                    if browser_tool:\n                        # we normalize on the full current text to get the right indices in citations\n                        updated_output_text, annotations, has_partial_citations = (\n                            browser_tool.normalize_citations(\n                                current_output_text_content + output_delta_buffer\n                            )\n                        )\n                        # remove the current text to get back the delta but now normalized\n                        output_delta_buffer = updated_output_text[\n                            len(current_output_text_content) :\n                        ]\n\n                        # Filter annotations to only include those whose start_index is not already present in current_annotations\n                        # this is to avoid sending duplicate annotations as multiple annotations can't be in the same place\n                        existing_start_indices = {\n                            a[\"start_index\"] for a in current_annotations\n                        }\n                        new_annotations = [\n                            a\n                            for a in annotations\n                            if a[\"start_index\"] not in 
existing_start_indices\n                        ]\n                        for a in new_annotations:\n                            current_annotations.append(a)\n                            citation = UrlCitation(**a)\n                            message_id = self._ensure_message_item_id()\n                            yield self._send_event(\n                                ResponseOutputTextAnnotationAdded(\n                                    type=\"response.output_text.annotation.added\",\n                                    output_index=current_output_index,\n                                    content_index=current_content_index,\n                                    item_id=message_id,\n                                    annotation_index=len(current_annotations),\n                                    annotation=citation,\n                                )\n                            )\n\n                        if has_partial_citations:\n                            should_send_output_text_delta = False\n\n                    if should_send_output_text_delta:\n                        message_id = self._ensure_message_item_id()\n                        yield self._send_event(\n                            ResponseOutputTextDelta(\n                                type=\"response.output_text.delta\",\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                item_id=message_id,\n                                delta=output_delta_buffer,\n                            )\n                        )\n                        current_output_text_content += output_delta_buffer\n                        output_delta_buffer = \"\"\n\n                if (\n                    self.parser.last_content_delta\n                    and self.parser.current_channel == \"analysis\"\n                    and self.parser.current_recipient is None\n                ):\n            
        if not sent_output_item_added:\n                        sent_output_item_added = True\n                        reasoning_id = self._ensure_reasoning_item_id()\n                        yield self._send_event(\n                            ResponseOutputItemAdded(\n                                type=\"response.output_item.added\",\n                                output_index=current_output_index,\n                                item=ReasoningItem(\n                                    id=reasoning_id,\n                                    type=\"reasoning\",\n                                    summary=[],\n                                    content=[],\n                                ),\n                            )\n                        )\n                        yield self._send_event(\n                            ResponseContentPartAdded(\n                                type=\"response.content_part.added\",\n                                output_index=current_output_index,\n                                content_index=current_content_index,\n                                item_id=reasoning_id,\n                                part=ReasoningTextContentItem(\n                                    type=\"reasoning_text\", text=\"\"\n                                ),\n                            )\n                        )\n                    reasoning_id = self._ensure_reasoning_item_id()\n                    yield self._send_event(\n                        ResponseReasoningTextDelta(\n                            type=\"response.reasoning_text.delta\",\n                            output_index=current_output_index,\n                            content_index=current_content_index,\n                            item_id=reasoning_id,\n                            delta=self.parser.last_content_delta,\n                        )\n                    )\n\n                try:\n                    # purely for debugging purposes\n                    
output_token_text = encoding.decode_utf8([next_tok])\n                    self.output_text += output_token_text\n                    print(output_token_text, end=\"\", flush=True)\n\n                except RuntimeError:\n                    pass\n\n                if next_tok in encoding.stop_tokens_for_assistant_actions():\n                    if len(self.parser.messages) > 0:\n                        last_message = self.parser.messages[-1]\n                        browser_recipient, is_browser_fallback = (\n                            self._resolve_browser_recipient(last_message.recipient)\n                        )\n                        if browser_recipient is not None and browser_tool is not None:\n                            message_for_browser = (\n                                last_message\n                                if not is_browser_fallback\n                                else last_message.with_recipient(browser_recipient)\n                            )\n                            function_name = browser_recipient[len(\"browser.\") :]\n                            action = None\n                            parsed_args = browser_tool.process_arguments(\n                                message_for_browser\n                            )\n                            if function_name == \"search\":\n                                action = WebSearchActionSearch(\n                                    type=\"search\",\n                                    query=parsed_args[\"query\"],\n                                )\n                            elif function_name == \"open\":\n                                action = WebSearchActionOpenPage(\n                                    type=\"open_page\",\n                                    url=(\n                                        parsed_args[\"url\"]\n                                        if \"url\" in parsed_args\n                                        else None\n                                  
  ),\n                                )\n                            elif function_name == \"find\":\n                                action = WebSearchActionFind(\n                                    type=\"find\",\n                                    pattern=parsed_args[\"pattern\"],\n                                    url=(\n                                        parsed_args[\"url\"]\n                                        if \"url\" in parsed_args\n                                        else None\n                                    ),\n                                )\n\n                            if action is not None:\n                                web_search_call_id = f\"ws_{uuid.uuid4().hex}\"\n                                self.browser_call_ids.append(web_search_call_id)\n                                yield self._send_event(\n                                    ResponseOutputItemAdded(\n                                        type=\"response.output_item.added\",\n                                        output_index=current_output_index,\n                                        item=WebSearchCallItem(\n                                            type=\"web_search_call\",\n                                            id=web_search_call_id,\n                                            action=action,\n                                        ),\n                                    )\n                                )\n                            yield self._send_event(\n                                ResponseWebSearchCallInProgress(\n                                    type=\"response.web_search_call.in_progress\",\n                                    output_index=current_output_index,\n                                    item_id=web_search_call_id,\n                                )\n                            )\n\n                            async def run_tool():\n                                results = []\n                                
async for msg in browser_tool.process(\n                                    message_for_browser\n                                ):\n                                    results.append(msg)\n                                return results\n\n                            yield self._send_event(\n                                ResponseWebSearchCallSearching(\n                                    type=\"response.web_search_call.searching\",\n                                    output_index=current_output_index,\n                                    item_id=web_search_call_id,\n                                )\n                            )\n                            result = await run_tool()\n\n                            new_tokens = encoding.render_conversation_for_completion(\n                                Conversation.from_messages(result), Role.ASSISTANT\n                            )\n\n                            print(encoding.decode_utf8(new_tokens))\n                            self.output_tokens.append(next_tok)\n                            self.tokens.append(\n                                encoding.encode(\"<|end|>\", allowed_special=\"all\")[0]\n                            )\n\n                            for token in new_tokens:\n                                self.parser.process(token)\n                                self.output_tokens.append(token)\n                                self.tokens.append(token)\n\n                            yield self._send_event(\n                                ResponseWebSearchCallCompleted(\n                                    type=\"response.web_search_call.completed\",\n                                    output_index=current_output_index,\n                                    item_id=web_search_call_id,\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseOutputItemDone(\n                                    
type=\"response.output_item.done\",\n                                    output_index=current_output_index,\n                                    item=WebSearchCallItem(\n                                        type=\"web_search_call\",\n                                        id=web_search_call_id,\n                                        action=action,\n                                    ),\n                                )\n                            )\n\n                            current_output_index += 1\n                            self.new_request = True\n\n                            continue\n\n                        elif (\n                            self.use_code_interpreter\n                            and last_message.recipient is not None\n                            and (\n                                last_message.recipient.startswith(\"python\")\n                                or (\n                                    self.functions_python_as_builtin\n                                    and last_message.recipient == \"functions.python\"\n                                )\n                            )\n                        ):\n                            code_call_id = f\"ci_{uuid.uuid4().hex}\"\n                            code_snippet = None\n                            if (\n                                last_message.content\n                                and len(last_message.content) > 0\n                                and getattr(last_message.content[0], \"text\", None)\n                            ):\n                                text_value = last_message.content[0].text or \"\"\n                                code_snippet = text_value if text_value.strip() else None\n\n                            self.python_call_ids.append(code_call_id)\n                            yield self._send_event(\n                                ResponseOutputItemAdded(\n                                    type=\"response.output_item.added\",\n 
                                   output_index=current_output_index,\n                                    item=CodeInterpreterCallItem(\n                                        type=\"code_interpreter_call\",\n                                        id=code_call_id,\n                                        status=\"in_progress\",\n                                        code=code_snippet,\n                                    ),\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseCodeInterpreterCallInProgress(\n                                    type=\"response.code_interpreter_call.in_progress\",\n                                    output_index=current_output_index,\n                                    item_id=code_call_id,\n                                )\n                            )\n                            if code_snippet:\n                                yield self._send_event(\n                                    ResponseCodeInterpreterCallCodeDelta(\n                                        type=\"response.code_interpreter_call_code.delta\",\n                                        output_index=current_output_index,\n                                        item_id=code_call_id,\n                                        delta=code_snippet,\n                                    )\n                                )\n                                yield self._send_event(\n                                    ResponseCodeInterpreterCallCodeDone(\n                                        type=\"response.code_interpreter_call_code.done\",\n                                        output_index=current_output_index,\n                                        item_id=code_call_id,\n                                        code=code_snippet,\n                                    )\n                                )\n                            yield 
self._send_event(\n                                ResponseCodeInterpreterCallInterpreting(\n                                    type=\"response.code_interpreter_call.interpreting\",\n                                    output_index=current_output_index,\n                                    item_id=code_call_id,\n                                )\n                            )\n\n                            async def run_python_tool():\n                                results = []\n                                async for msg in self.python_tool.process(last_message):\n                                    results.append(msg)\n                                return results\n\n                            result = await run_python_tool()\n\n                            print(result)\n\n                            code_outputs: list[\n                                CodeInterpreterOutputLogs | CodeInterpreterOutputImage\n                            ] = []\n                            for message in result:\n                                for content in getattr(message, \"content\", []):\n                                    text_value = getattr(content, \"text\", None)\n                                    if text_value:\n                                        code_outputs.append(\n                                            CodeInterpreterOutputLogs(\n                                                type=\"logs\",\n                                                logs=text_value,\n                                            )\n                                        )\n\n                            self.python_call_outputs[code_call_id] = code_outputs\n\n                            new_tokens = encoding.render_conversation_for_completion(\n                                Conversation.from_messages(result), Role.ASSISTANT\n                            )\n\n                            print(encoding.decode_utf8(new_tokens))\n                            
self.output_tokens.append(next_tok)\n                            self.tokens.append(\n                                encoding.encode(\"<|end|>\", allowed_special=\"all\")[0]\n                            )\n\n                            for token in new_tokens:\n                                self.parser.process(token)\n                                self.output_tokens.append(token)\n                                self.tokens.append(token)\n\n                            yield self._send_event(\n                                ResponseCodeInterpreterCallCompleted(\n                                    type=\"response.code_interpreter_call.completed\",\n                                    output_index=current_output_index,\n                                    item_id=code_call_id,\n                                )\n                            )\n                            yield self._send_event(\n                                ResponseOutputItemDone(\n                                    type=\"response.output_item.done\",\n                                    output_index=current_output_index,\n                                    item=CodeInterpreterCallItem(\n                                        type=\"code_interpreter_call\",\n                                        id=code_call_id,\n                                        status=\"completed\",\n                                        code=code_snippet,\n                                        outputs=code_outputs or None,\n                                    ),\n                                )\n                            )\n\n                            current_output_index += 1\n                            self.new_request = True\n\n                            continue\n\n                        else:\n                            break\n                    else:\n                        raise ValueError(\"No messages to process\")\n                if len(self.output_tokens) >= 
self.request_body.max_output_tokens:\n                    break\n\n                # Adding in the end if we know we are not done\n                self.output_tokens.append(next_tok)\n\n            if self.request is None or not await self.request.is_disconnected():\n                response = generate_response(\n                    self.initial_tokens,\n                    self.output_tokens,\n                    self.request_body,\n                    debug_mode=self.debug_mode,\n                    function_call_ids=self.function_call_ids,\n                    response_id=self.response_id,\n                    previous_response_id=self.request_body.previous_response_id,\n                    browser_tool=self.browser_tool,\n                    browser_call_ids=self.browser_call_ids,\n                    python_tool=self.python_tool,\n                    python_call_ids=self.python_call_ids,\n                    python_call_outputs=self.python_call_outputs,\n                    reasoning_ids=self.reasoning_item_ids,\n                    message_ids=self.message_item_ids,\n                    treat_functions_python_as_builtin=self.functions_python_as_builtin,\n                )\n                if self.store_callback and self.request_body.store:\n                    self.store_callback(self.response_id, self.request_body, response)\n                yield self._send_event(\n                    ResponseCompletedEvent(\n                        type=\"response.completed\",\n                        response=response,\n                    )\n                )\n\n    @app.post(\"/v1/responses\", response_model=ResponseObject)\n    async def generate(body: ResponsesRequest, request: Request):\n        print(\"request received\")\n        print(body.reasoning)\n\n        use_browser_tool = any(\n            getattr(tool, \"type\", None) in (\"browser_search\", \"web_search\")\n            for tool in (body.tools or [])\n        )\n        use_code_interpreter = any(\n       
     getattr(tool, \"type\", None) == \"code_interpreter\"\n            for tool in (body.tools or [])\n        )\n\n        if use_browser_tool:\n            tool_backend = os.getenv(\"BROWSER_BACKEND\", \"exa\")\n            if tool_backend == \"youcom\":\n                backend = YouComBackend(source=\"web\")\n            elif tool_backend == \"exa\":\n                backend = ExaBackend(source=\"web\")\n            else:\n                raise ValueError(f\"Invalid tool backend: {tool_backend}\")\n            browser_tool = SimpleBrowserTool(backend=backend)\n        else:\n            browser_tool = None\n\n        if use_code_interpreter:\n            python_tool = PythonTool()\n        else:\n            python_tool = None\n\n        python_function_name_conflict = any(\n            getattr(tool, \"type\", None) == \"function\"\n            and getattr(tool, \"name\", None) == \"python\"\n            for tool in (body.tools or [])\n        )\n        functions_python_as_builtin = use_code_interpreter and not (\n            python_function_name_conflict\n        )\n\n        if body.previous_response_id:\n            prev = responses_store.get(body.previous_response_id)\n            if prev:\n                prev_req, prev_resp = prev\n\n                def _ensure_list(inp):\n                    if isinstance(inp, str):\n                        return [\n                            Item(\n                                type=\"message\",\n                                role=\"user\",\n                                content=[TextContentItem(type=\"input_text\", text=inp)],\n                            )\n                        ]\n                    return list(inp)\n\n                merged_input = _ensure_list(prev_req.input) + list(prev_resp.output)\n                merged_input.extend(_ensure_list(body.input))\n\n                if body.instructions is None:\n                    body.instructions = prev_req.instructions\n                body.input = 
merged_input\n\n        system_message_content = SystemContent.new().with_conversation_start_date(\n            datetime.datetime.now().strftime(\"%Y-%m-%d\")\n        )\n\n        if body.reasoning is not None:\n            try:\n\n                reasoning_effort = get_reasoning_effort(body.reasoning.effort)\n            except ValueError as e:\n                from fastapi import HTTPException\n                print(e)\n\n                raise HTTPException(status_code=422, detail=str(e))\n            system_message_content = system_message_content.with_reasoning_effort(\n                reasoning_effort\n            )\n\n        if use_browser_tool:\n            system_message_content = system_message_content.with_tools(\n                browser_tool.tool_config\n            )\n        if use_code_interpreter:\n            system_message_content = system_message_content.with_tools(\n                python_tool.tool_config\n            )\n\n        system_message = Message.from_role_and_content(\n            Role.SYSTEM, system_message_content\n        )\n        messages = [system_message]\n\n        if body.instructions or body.tools:\n            developer_message_content = DeveloperContent.new().with_instructions(\n                body.instructions\n            )\n\n            tools = []\n            for tool in body.tools:\n                if tool.type == \"function\":\n                    tools.append(\n                        ToolDescription.new(\n                            tool.name,\n                            tool.description,\n                            tool.parameters,\n                        )\n                    )\n\n            if tools:\n                developer_message_content = (\n                    developer_message_content.with_function_tools(tools)\n                )\n\n            developer_message = Message.from_role_and_content(\n                Role.DEVELOPER, developer_message_content\n            )\n\n            
messages.append(developer_message)\n\n        if isinstance(body.input, str):\n            user_message = Message.from_role_and_content(Role.USER, body.input)\n            messages.append(user_message)\n        else:\n            is_last_message_function_call_output = (\n                len(body.input) > 0 and body.input[-1].type == \"function_call_output\"\n            )\n            function_call_map = {}\n            # Find the index of the last assistant message\n            last_assistant_idx = -1\n            for idx, item in enumerate(body.input):\n                if item.type == \"message\" and item.role == Role.ASSISTANT:\n                    last_assistant_idx = idx\n\n            for idx, item in enumerate(body.input):\n                if item.type == \"message\":\n                    # TODO: add system prompt handling\n                    if isinstance(item.content, str):\n                        messages.append(\n                            Message.from_role_and_content(item.role, item.content)\n                        )\n                    else:\n                        for content_item in item.content:\n                            messages.append(\n                                Message.from_role_and_content(\n                                    item.role, content_item.text\n                                )\n                            )\n                    # add final channel to the last assistant message if it's from the assistant\n                    if item.role == Role.ASSISTANT:\n                        messages[-1] = messages[-1].with_channel(\"final\")\n                elif item.type == \"reasoning\":\n                    # Only include reasoning if it is after the last assistant message and we are handling a function call at the moment\n                    if (\n                        idx > last_assistant_idx\n                        and is_last_message_function_call_output\n                    ):\n                        for 
content_item in item.content:\n                            messages.append(\n                                Message.from_role_and_content(\n                                    Role.ASSISTANT, content_item.text\n                                ).with_channel(\"analysis\")\n                            )\n                elif item.type == \"function_call\":\n                    function_call_map[item.call_id] = item\n                    messages.append(\n                        Message.from_role_and_content(Role.ASSISTANT, item.arguments)\n                        .with_recipient(f\"functions.{item.name}\")\n                        .with_channel(\"commentary\")\n                    )\n                elif item.type == \"function_call_output\":\n                    function_call = function_call_map.get(item.call_id, None)\n                    if not function_call:\n                        raise ValueError(f\"Function call {item.call_id} not found\")\n\n                    messages.append(\n                        Message.from_author_and_content(\n                            Author.new(Role.TOOL, f\"functions.{function_call.name}\"),\n                            item.output,\n                        )\n                        .with_recipient(\"assistant\")\n                        .with_channel(\"commentary\")\n                    )\n\n        conversation = Conversation.from_messages(messages)\n\n        initial_tokens = encoding.render_conversation_for_completion(\n            conversation, Role.ASSISTANT\n        )\n        print(encoding.decode_utf8(initial_tokens))\n        response_id = f\"resp_{uuid.uuid4().hex}\"\n\n        def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):\n            responses_store[rid] = (req, resp)\n\n        event_stream = StreamResponsesEvents(\n            initial_tokens,\n            body,\n            as_sse=body.stream,\n            request=request,\n            response_id=response_id,\n            
store_callback=store_callback,\n            browser_tool=browser_tool,\n            python_tool=python_tool,\n            functions_python_as_builtin=functions_python_as_builtin,\n        )\n\n        if body.stream:\n            return StreamingResponse(event_stream.run(), media_type=\"text/event-stream\")\n        else:\n            last_event = None\n            async for event in event_stream.run():\n                last_event = event\n\n            return last_event.response\n\n    return app\n"
  },
  {
    "path": "gpt_oss/responses_api/events.py",
    "content": "# torchrun --nproc-per-node=4 responses_api.py\nfrom typing import Literal, Optional, Union\n\nfrom pydantic import BaseModel\n\nfrom .types import (\n    CodeInterpreterCallItem,\n    CodeInterpreterOutputImage,\n    CodeInterpreterOutputLogs,\n    FunctionCallItem,\n    Item,\n    ReasoningItem,\n    ReasoningTextContentItem,\n    ResponseObject,\n    TextContentItem,\n    UrlCitation,\n    WebSearchCallItem,\n)\n\n\nclass ResponseEvent(BaseModel):\n    sequence_number: Optional[int] = 1\n\n\nclass ResponseCreatedEvent(ResponseEvent):\n    type: Literal[\"response.created\"]\n    response: ResponseObject\n\n\nclass ResponseCompletedEvent(ResponseEvent):\n    type: Literal[\"response.completed\"]\n    response: ResponseObject\n\n\nclass ResponseOutputTextDelta(ResponseEvent):\n    type: Literal[\"response.output_text.delta\"] = \"response.output_text.delta\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    delta: str = \"\"\n    logprobs: list = []\n\n\nclass ResponseReasoningSummaryTextDelta(ResponseEvent):\n    type: Literal[\"response.reasoning_summary_text.delta\"] = (\n        \"response.reasoning_summary_text.delta\"\n    )\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    delta: str = \"\"\n\n\nclass ResponseReasoningTextDelta(ResponseEvent):\n    type: Literal[\"response.reasoning_text.delta\"] = \"response.reasoning_text.delta\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    delta: str = \"\"\n\n\nclass ResponseReasoningTextDone(ResponseEvent):\n    type: Literal[\"response.reasoning_text.done\"] = \"response.reasoning_text.done\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    text: str = \"\"\n\n\nclass ResponseOutputItemAdded(ResponseEvent):\n    type: Literal[\"response.output_item.added\"] = \"response.output_item.added\"\n    output_index: int = 0\n    
item: Union[\n        Item,\n        ReasoningItem,\n        FunctionCallItem,\n        WebSearchCallItem,\n        CodeInterpreterCallItem,\n    ]\n\n\nclass ResponseOutputItemDone(ResponseEvent):\n    type: Literal[\"response.output_item.done\"] = \"response.output_item.done\"\n    output_index: int = 0\n    item: Union[\n        Item,\n        ReasoningItem,\n        FunctionCallItem,\n        WebSearchCallItem,\n        CodeInterpreterCallItem,\n    ]\n\n\nclass ResponseInProgressEvent(ResponseEvent):\n    type: Literal[\"response.in_progress\"]\n    response: ResponseObject\n\n\nclass ResponseContentPartAdded(ResponseEvent):\n    type: Literal[\"response.content_part.added\"] = \"response.content_part.added\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    part: Union[TextContentItem, ReasoningTextContentItem]\n\n\nclass ResponseOutputTextDone(ResponseEvent):\n    type: Literal[\"response.output_text.done\"] = \"response.output_text.done\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    text: str = \"\"\n    logprobs: list = []\n\n\nclass ResponseContentPartDone(ResponseEvent):\n    type: Literal[\"response.content_part.done\"] = \"response.content_part.done\"\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    part: Union[TextContentItem, ReasoningTextContentItem]\n\n\nclass ResponseOutputTextAnnotationAdded(ResponseEvent):\n    type: Literal[\"response.output_text.annotation.added\"] = (\n        \"response.output_text.annotation.added\"\n    )\n    item_id: str = \"item_1234\"\n    output_index: int = 0\n    content_index: int = 0\n    annotation_index: int = 0\n    annotation: UrlCitation\n\n\nclass ResponseWebSearchCallInProgress(ResponseEvent):\n    type: Literal[\"response.web_search_call.in_progress\"] = (\n        \"response.web_search_call.in_progress\"\n    )\n    output_index: int = 0\n    item_id: str = 
\"item_1234\"\n\n\nclass ResponseWebSearchCallSearching(ResponseEvent):\n    type: Literal[\"response.web_search_call.searching\"] = (\n        \"response.web_search_call.searching\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n\n\nclass ResponseWebSearchCallCompleted(ResponseEvent):\n    type: Literal[\"response.web_search_call.completed\"] = (\n        \"response.web_search_call.completed\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n\n\nclass ResponseCodeInterpreterCallInProgress(ResponseEvent):\n    type: Literal[\"response.code_interpreter_call.in_progress\"] = (\n        \"response.code_interpreter_call.in_progress\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n\n\nclass ResponseCodeInterpreterCallInterpreting(ResponseEvent):\n    type: Literal[\"response.code_interpreter_call.interpreting\"] = (\n        \"response.code_interpreter_call.interpreting\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n\n\nclass ResponseCodeInterpreterCallCodeDelta(ResponseEvent):\n    type: Literal[\"response.code_interpreter_call_code.delta\"] = (\n        \"response.code_interpreter_call_code.delta\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n    delta: str = \"\"\n    code_output: Optional[\n        Union[CodeInterpreterOutputLogs, CodeInterpreterOutputImage]\n    ] = None\n\n\nclass ResponseCodeInterpreterCallCodeDone(ResponseEvent):\n    type: Literal[\"response.code_interpreter_call_code.done\"] = (\n        \"response.code_interpreter_call_code.done\"\n    )\n    output_index: int = 0\n    item_id: str = \"item_1234\"\n    code: str = \"\"\n    outputs: Optional[\n        list[Union[CodeInterpreterOutputLogs, CodeInterpreterOutputImage]]\n    ] = None\n\n\nclass ResponseCodeInterpreterCallCompleted(ResponseEvent):\n    type: Literal[\"response.code_interpreter_call.completed\"] = (\n        \"response.code_interpreter_call.completed\"\n    )\n    
output_index: int = 0\n    item_id: str = \"item_1234\"\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/responses_api/inference/metal.py",
    "content": "\"\"\"Metal backend for :mod:`gpt_oss.responses_api`.\"\"\"\n\nfrom typing import Callable\n\nfrom gpt_oss.metal import Context, Model\n\n\n# Tunables\nMAX_OUTPUT_TOKENS = 100\n\n\ndef setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:\n    \"\"\"Load the Metal model and return an inference function.\"\"\"\n\n    model = Model(checkpoint)\n    context = Context(model)\n\n    seed = 0\n    output_tokens = []\n\n    def infer_next_token(\n        tokens: list[int], temperature: float = 0.0, new_request: bool = False\n    ) -> int:\n        \"\"\"Infer next token using incremental LCP caching when possible.\"\"\"\n        nonlocal output_tokens\n\n        if new_request:\n            output_tokens = []\n\n        if len(output_tokens) == 0:\n            # Context handles LCP caching internally; if `tokens` matches the\n            # tokens in the KV cache, the KV cache is reused after reset+append.\n            context.reset()\n            for t in tokens:\n                context.append(t)\n\n            output_tokens = context.sample(max_output_tokens=MAX_OUTPUT_TOKENS,\n                                           temperature=temperature,\n                                           seed=seed)\n\n        return int(output_tokens.pop(0))\n\n    return infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/ollama.py",
    "content": "\"\"\"\nNOTE: this is a stitched together implementation that uses Ollama for inference. It's primarily used\nfor testing and development. It does not leverage any prompt caching or other optimizations and\ncan therefore be slow between turns.\n\"\"\"\n\nimport json\nimport threading\nimport time\nfrom typing import Callable, Optional\n\nimport requests\nfrom openai_harmony import HarmonyEncodingName, load_harmony_encoding\n\nEOS_TOKEN = 200002  # only used on hard timeout\n\n# Tunables\nPOLL_INTERVAL_S = 0.01  # 10ms between buffer checks\nCALL_MAX_WAIT_S = 0.250  # max time to block inside a single infer call\nNO_TOKEN_TIMEOUT_S = 15.0  # overall inactivity timeout before emitting EOS\nFIRST_BYTE_TIMEOUT_S = 30.0  # time to wait for first token before EOS\n\n# Shared state\n_token_buffer: list[int] = []\n_buffer_lock = threading.Lock()\n_stream_thread: Optional[threading.Thread] = None\n_stream_done = threading.Event()\n_stream_error: Optional[Exception] = None\n_last_progress_ts: float = 0.0  # updated whenever we enqueue or dequeue tokens\n_previous_request_tokens: list[int] = []\n\n\ndef lcp(cache: list[int], inp: list[int]) -> list[int]:\n    i = 0\n    max_len = min(len(cache), len(inp))\n    while i < max_len and cache[i] == inp[i]:\n        i += 1\n    return cache[:i]\n\n\ndef _now():\n    return time.monotonic()\n\n\ndef _touch_progress():\n    global _last_progress_ts\n    _last_progress_ts = _now()\n\n\ndef _reset_stream_state():\n    global _token_buffer, _stream_thread, _stream_error\n    with _buffer_lock:\n        _token_buffer = []\n    _stream_done.clear()\n    _stream_thread = None\n    _stream_error = None\n    _touch_progress()\n\n\ndef setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:\n    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n    model_name = checkpoint\n\n    def _start_stream(token_ids: list[int], temperature: float):\n        prompt_text = 
encoding.decode(token_ids)\n\n        def run():\n            nonlocal prompt_text, temperature\n            global _stream_error\n            global _previous_request_tokens\n\n            accum_text = \"\"\n            last_len = 0  # number of tokens already emitted\n\n            try:\n                url = \"http://localhost:11434/api/generate\"\n\n                payload = {\n                    \"model\": model_name,\n                    \"prompt\": prompt_text,\n                    \"stream\": True,\n                    \"options\": {\"temperature\": temperature},\n                    \"raw\": True,\n                }\n\n                with requests.post(url, json=payload, stream=True, timeout=60) as resp:\n                    resp.raise_for_status()\n                    for line in resp.iter_lines(decode_unicode=True):\n                        if not line:\n                            continue\n                        obj = json.loads(line)\n\n                        if isinstance(obj.get(\"response\"), str):\n                            accum_text += obj[\"response\"]\n                            toks = encoding.encode(accum_text, allowed_special=\"all\")\n                            if len(toks) > last_len:\n                                new_toks = toks[last_len:]\n                                with _buffer_lock:\n                                    _token_buffer.extend(new_toks)\n                                last_len = len(toks)\n                                _touch_progress()\n\n                        if obj.get(\"done\", False):\n                            # Guard the shared buffer like the other writers do.\n                            with _buffer_lock:\n                                _token_buffer.append(EOS_TOKEN)\n                            _touch_progress()\n                            break\n\n                _stream_done.set()\n\n            except Exception as e:\n                _stream_error = e\n                _stream_done.set()\n\n        t = threading.Thread(target=run, name=\"ollama-stream\", 
daemon=True)\n        t.start()\n        return t\n\n    def infer_next_token(\n        tokens: list[int], temperature: float = 0.0, new_request: bool = False\n    ) -> int:\n        \"\"\"\n        - Starts a new Ollama stream on new_request.\n        - Forwards tokens as they arrive.\n        - Only emits EOS_TOKEN if we exceed an inactivity timeout.\n        \"\"\"\n        global _stream_thread\n\n        if new_request:\n            _reset_stream_state()\n            _stream_thread = _start_stream(token_ids=tokens, temperature=temperature)\n            # Wait for first byte within FIRST_BYTE_TIMEOUT_S (without emitting EOS early)\n            start = _now()\n            while _now() - start < FIRST_BYTE_TIMEOUT_S:\n                with _buffer_lock:\n                    if _token_buffer:\n                        tok = _token_buffer.pop(0)\n                        _touch_progress()\n                        return tok\n                if _stream_error is not None:\n                    raise RuntimeError(f\"Ollama stream error: {_stream_error!r}\")\n                # If Ollama finished instantly with no output, continue loop until timeout\n                time.sleep(POLL_INTERVAL_S)\n            # Hard first-byte timeout -> emit EOS so the server can stop this request\n            return EOS_TOKEN\n\n        if _stream_error is not None:\n            raise RuntimeError(f\"Ollama stream error: {_stream_error!r}\")\n\n        # Normal path: wait up to CALL_MAX_WAIT_S for a token to arrive\n        wait_start = _now()\n        while _now() - wait_start < CALL_MAX_WAIT_S:\n            with _buffer_lock:\n                if _token_buffer:\n                    tok = _token_buffer.pop(0)\n                    _touch_progress()\n                    return tok\n            # No token yet; if we've been idle too long overall, end with EOS\n            if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:\n                return EOS_TOKEN\n            
time.sleep(POLL_INTERVAL_S)\n\n        # Still no token in this call slice. Do NOT send EOS unless we've timed out.\n        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:\n            return EOS_TOKEN\n\n        # Tell caller to call us again; block minimally by returning *nothing new*.\n        # We must return an int; safest is to wait a tiny bit longer for a token.\n        # If still none, keep returning only after short waits. Avoid EOS here.\n        # One more short wait to reduce hot-looping:\n        time.sleep(POLL_INTERVAL_S)\n        with _buffer_lock:\n            if _token_buffer:\n                tok = _token_buffer.pop(0)\n                _touch_progress()\n                return tok\n\n        # As a last resort for this call slice, return EOS only on true inactivity timeout.\n        if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:\n            return EOS_TOKEN\n\n        # If we reach here, we still haven't got a token; ask the caller to call again soon.\n        # Return a harmless token that the server will replace/ignore if your interface supports it.\n        # If your interface does NOT allow a sentinel, keep the short-blocking behavior above.\n        return 0  # replace `0` with a PAD/NOOP token your server ignores\n\n    return infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/stub.py",
    "content": "import time\nfrom typing import Callable\n\nfake_tokens = [\n    200005,\n    35644,\n    200008,\n    23483,\n    316,\n    1199,\n    1114,\n    717,\n    170154,\n    13,\n    200007,\n    200006,\n    173781,\n    200005,\n    35644,\n    316,\n    28,\n    44580,\n    775,\n    170154,\n    464,\n    91,\n    542,\n    141043,\n    91,\n    29,\n    4108,\n    200008,\n    10848,\n    7693,\n    7534,\n    28499,\n    18826,\n    18583,\n    200012,\n]\nfake_tokens = [\n    200005,\n    35644,\n    200008,\n    1844,\n    31064,\n    25,\n    392,\n    4827,\n    382,\n    220,\n    17,\n    659,\n    220,\n    17,\n    16842,\n    12295,\n    81645,\n    13,\n    51441,\n    6052,\n    13,\n    200007,\n    200006,\n    173781,\n    200005,\n    17196,\n    200008,\n    17,\n    659,\n    220,\n    17,\n    314,\n    220,\n    19,\n    13,\n    9552,\n    238,\n    242,\n    200002,\n]\n# fake_tokens = [200005, 35644, 200008, 976, 1825, 31064, 25, 392, 25216, 29400, 290, 11122, 306, 52768, 2117, 16842, 1416, 1309, 316, 2281, 198, 68, 290, 2208, 11122, 13, 1416, 679, 261, 1114, 717, 170154, 484, 44390, 261, 5100, 1621, 26, 581, 1757, 2005, 198, 75, 480, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 7801, 4733, 290, 11122, 5377, 484, 290, 1114, 7377, 13, 1416, 1309, 260, 198, 78, 1199, 290, 1114, 4584, 364, 58369, 2421, 717, 170154, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 200007, 200006, 173781, 200005, 12606, 815, 260, 198, 78, 28, 117673, 3490]\n# fake_tokens = [\n#     198,\n#     200005,\n#     35644,\n#     200008,\n#     23483,\n#     316,\n#     1199,\n#     1114,\n#     717,\n#     170154,\n#     13,\n#     200007,\n#     200006,\n#     173781,\n#     200005,\n#     12606,\n#     815,\n#     316,\n#     32455,\n#     106847,\n#     316,\n#     28,\n#     44580,\n#     775,\n#     170154,\n#     464,\n#     91,\n#     542,\n#     141043,\n#     91,\n#     29,\n#     4108,\n#     200008,\n#     10848,\n#     7693,\n#     7534,\n#   
  28499,\n#     18826,\n#     18583,\n#     200012,\n#     198,\n# ]\n\ntoken_queue = fake_tokens.copy()\n\n\ndef stub_infer_next_token(\n    tokens: list[int], temperature: float = 0.0, new_request: bool = False\n) -> int:\n    global token_queue\n    next_tok = token_queue.pop(0)\n    if len(token_queue) == 0:\n        token_queue = fake_tokens.copy()\n    time.sleep(0.1)\n    return next_tok\n\n\ndef setup_model(_checkpoint: str) -> Callable[[list[int], float, bool], int]:\n    return stub_infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/transformers.py",
    "content": "\"\"\"\nNOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers\none token at a time to mimic the behavior of the Triton implementation.\n\"\"\"\n\nimport os\nfrom typing import Callable, List\n\n# Transformers imports\nfrom transformers import AutoModelForCausalLM, PreTrainedModel\nimport torch\n\n\nDEFAULT_TEMPERATURE = 0.0\nTP = os.environ.get(\"TP\", 2)\n\ndef load_model(checkpoint: str):\n    \"\"\"\n    Serve the model directly with the Auto API.\n    \"\"\"\n\n    model = AutoModelForCausalLM.from_pretrained(\n        checkpoint,\n        torch_dtype=torch.bfloat16,\n        device_map=\"auto\",\n    )\n\n    return model\n\n\ndef get_infer_next_token(model: PreTrainedModel):\n    \"\"\"\n    Return a callable with the same shape as the original triton implementation:\n      infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int\n\n    Implementation detail:\n      - We issue a single-token generation using model.generate\n      - generate handles sampling (temperature=0 => greedy; otherwise sampling).\n    \"\"\"\n\n    def infer_next_token(\n        tokens: List[int],\n        temperature: float = DEFAULT_TEMPERATURE,\n        new_request: bool = False, # kept for interface compatibility; unused here\n    ) -> int:\n        tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)\n        output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature)\n        return output[0, -1].tolist()\n\n    return infer_next_token\n\n\ndef setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:\n    model = load_model(checkpoint)\n    infer_next_token = get_infer_next_token(model)\n    return infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/triton.py",
    "content": "import datetime\nimport os\nfrom typing import Callable\n\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\nimport torch\nimport torch.distributed as dist\n\nfrom gpt_oss.triton.model import Cache, ModelConfig, Transformer\n\nDEFAULT_TEMPERATURE = 0.0\nCONTEXT = 16_384\nCONCURRENT_SESSIONS = 1\n\nrank = int(\n    os.environ.get(\"RANK\", 0)\n)  # set this env var to another value to run on other GPUs\n\n\ndef load_model(checkpoint: str):\n    print(f\"[{rank}] loading model...\")\n\n    torch.cuda.set_device(rank)\n    torch.set_grad_enabled(False)\n    device = torch.device(f\"cuda:{rank}\")\n\n    # Load model\n    model = Transformer.from_checkpoint(checkpoint, device=device)\n\n    print(f\"[{rank}] loaded\")\n    return model, device\n\n\ndef get_infer_next_token(model, device):\n    caches = [\n        Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)\n        for _ in range(len(model.block))\n    ]\n    # offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device) # TBD\n    input_token = torch.zeros(\n        1, dtype=torch.int32, device=device\n    )  # add concurrent sessions support\n    tokens_so_far = []\n\n    model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph):\n        logits = model(input_token[None, :], caches=caches)[0]\n\n    def lcp(cache: list[int], inp: list[int]) -> list[int]:\n        i = 0\n        max_len = min(len(cache), len(inp))\n        while i < max_len and cache[i] == inp[i]:\n            i += 1\n        return cache[:i]\n\n    def sample_next_token(\n        logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE\n    ) -> int:\n        \"\"\"Executed only on rank 0.\"\"\"\n        if temperature == 0.0:\n            return torch.argmax(logits[-1, :], dim=-1).item()\n        probs = torch.softmax(logits * (1.0 / temperature), dim=-1)\n        return 
torch.multinomial(probs[-1, :], num_samples=1).item()\n\n    @torch.inference_mode()\n    def infer_next_token(\n        tokens: list[int],\n        temperature: float = DEFAULT_TEMPERATURE,\n        new_request: bool = False,\n    ) -> int:\n        nonlocal tokens_so_far\n        tokens_so_far = lcp(tokens_so_far, tokens)\n        for cache in caches:\n            cache.truncate(len(tokens_so_far))\n        tokens = tokens[len(tokens_so_far) :]\n\n        # At least one uncached token is needed to drive the captured CUDA graph.\n        if len(tokens) == 0:\n            raise ValueError(\"tokens must contain at least one uncached token id\")\n\n        if len(tokens) > 1:\n            model.prefill(\n                torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],\n                caches,\n            )\n\n        input_token[-1] = tokens[-1]\n        graph.replay()\n\n        # decide next token on rank‑0\n        next_tok = sample_next_token(logits, temperature=temperature)\n\n        return next_tok\n\n    return infer_next_token\n\n\ndef setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:\n    model, device = load_model(checkpoint)\n    infer_next_token = get_infer_next_token(model, device)\n    return infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/inference/vllm.py",
    "content": "\"\"\"\nNOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers\none token at a time to mimic the behavior of the Triton implementation.\n\"\"\"\n\nimport os\nfrom typing import Callable, List, Optional\n\n# vLLM imports\nfrom vllm import LLM, SamplingParams\nfrom vllm.inputs import TokensPrompt\n\nDEFAULT_TEMPERATURE = 0.0\nTP = int(os.environ.get(\"TP\", 2))\n\ndef load_model(checkpoint: str):\n    \"\"\"\n    Create the vLLM engine. We enable prefix caching so repeated prefixes\n    across calls can reuse KV cache for faster prefill.\n    \"\"\"\n\n    llm = LLM(\n        model=checkpoint,\n        tensor_parallel_size=TP,          # set >1 if you want TP across GPUs\n        enable_prefix_caching=True,      # reuse KV for shared prefixes\n        disable_log_stats=True,        # suppress periodic stats logging\n    )\n\n    return llm\n\n\ndef get_infer_next_token(llm: LLM):\n    \"\"\"\n    Return a callable with the same shape as the Triton implementation:\n      infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int\n\n    Implementation detail:\n      - We issue a single-token generation with TokensPrompt(prompt_token_ids=tokens).\n      - vLLM handles sampling (temperature=0 => greedy).\n      - With enable_prefix_caching=True, the shared prefix prefill can be reused\n        across calls that share the same prefix.\n    \"\"\"\n\n    # Keep the same signature as the other inference backends.\n    def infer_next_token(\n        tokens: List[int],\n        temperature: float = DEFAULT_TEMPERATURE,\n        new_request: bool = False,  # kept for interface compatibility; unused here\n    ) -> int:\n        if not tokens:\n            raise ValueError(\"tokens must contain at least one input token id\")\n\n        sampling = SamplingParams(\n            temperature=float(temperature),\n            max_tokens=1,            # we only want the next token\n            n=1,                     # single continuation\n            # You can expose/enable more controls here (top_p, top_k, etc.)\n        )\n\n        # Provide token IDs directly (no re-tokenization).\n        outputs = llm.generate(\n            TokensPrompt(prompt_token_ids=tokens),\n            sampling_params=sampling,\n        )\n\n        if not outputs or not outputs[0].outputs:\n            raise RuntimeError(\"vLLM returned empty outputs\")\n\n        gen = outputs[0].outputs[0]\n        if not gen.token_ids:\n            # If the model immediately finished (e.g., EOS), decide how you'd like\n            # to signal that. Here we raise; you could also return an EOS id.\n            raise RuntimeError(\"No next token was generated (possibly EOS).\")\n\n        next_tok = int(gen.token_ids[0])\n        return next_tok\n\n    return infer_next_token\n\n\ndef setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:\n    llm = load_model(checkpoint)\n    infer_next_token = get_infer_next_token(llm)\n    return infer_next_token\n"
  },
  {
    "path": "gpt_oss/responses_api/serve.py",
    "content": "# torchrun --nproc-per-node=4 serve.py\n\nimport argparse\n\nimport uvicorn\nfrom openai_harmony import (\n    HarmonyEncodingName,\n    load_harmony_encoding,\n)\n\nfrom .api_server import create_api_server\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Responses API server\")\n    parser.add_argument(\n        \"--checkpoint\",\n        metavar=\"FILE\",\n        type=str,\n        help=\"Path to the SafeTensors checkpoint\",\n        default=\"~/model\",\n        required=False,\n    )\n    parser.add_argument(\n        \"--port\",\n        metavar=\"PORT\",\n        type=int,\n        default=8000,\n        help=\"Port to run the server on\",\n    )\n    parser.add_argument(\n        \"--inference-backend\",\n        metavar=\"BACKEND\",\n        type=str,\n        help=\"Inference backend to use\",\n        # default to metal on macOS, triton on other platforms\n        default=\"metal\" if __import__(\"platform\").system() == \"Darwin\" else \"triton\",\n    )\n    args = parser.parse_args()\n\n    if args.inference_backend == \"triton\":\n        from .inference.triton import setup_model\n    elif args.inference_backend == \"stub\":\n        from .inference.stub import setup_model\n    elif args.inference_backend == \"metal\":\n        from .inference.metal import setup_model\n    elif args.inference_backend == \"ollama\":\n        from .inference.ollama import setup_model\n    elif args.inference_backend == \"vllm\":\n        from .inference.vllm import setup_model\n    elif args.inference_backend == \"transformers\":\n        from .inference.transformers import setup_model\n    else:\n        raise ValueError(f\"Invalid inference backend: {args.inference_backend}\")\n\n    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\n    infer_next_token = setup_model(args.checkpoint)\n    uvicorn.run(create_api_server(infer_next_token, encoding), port=args.port)\n"
  },
  {
    "path": "gpt_oss/responses_api/types.py",
    "content": "from typing import Any, Dict, Literal, Optional, Union\n\nfrom openai_harmony import ReasoningEffort\nfrom pydantic import BaseModel, ConfigDict\n\nMODEL_IDENTIFIER = \"gpt-oss-120b\"\nDEFAULT_TEMPERATURE = 0.0\nREASONING_EFFORT = ReasoningEffort.LOW\nDEFAULT_MAX_OUTPUT_TOKENS = 131072\n\n\nclass UrlCitation(BaseModel):\n    type: Literal[\"url_citation\"]\n    end_index: int\n    start_index: int\n    url: str\n    title: str\n\n\nclass TextContentItem(BaseModel):\n    type: Union[Literal[\"text\"], Literal[\"input_text\"], Literal[\"output_text\"]]\n    text: str\n    status: Optional[str] = \"completed\"\n    annotations: Optional[list[UrlCitation]] = None\n\n\nclass SummaryTextContentItem(BaseModel):\n    # using summary for compatibility with the existing API\n    type: Literal[\"summary_text\"]\n    text: str\n\n\nclass ReasoningTextContentItem(BaseModel):\n    type: Literal[\"reasoning_text\"]\n    text: str\n\n\nclass ReasoningItem(BaseModel):\n    id: str = \"rs_1234\"\n    type: Literal[\"reasoning\"]\n    summary: list[SummaryTextContentItem]\n    content: Optional[list[ReasoningTextContentItem]] = []\n\n\nclass Item(BaseModel):\n    id: Optional[str] = None\n    type: Optional[Literal[\"message\"]] = \"message\"\n    role: Literal[\"user\", \"assistant\", \"system\"]\n    content: Union[list[TextContentItem], str]\n    status: Union[Literal[\"in_progress\", \"completed\", \"incomplete\"], None] = None\n\n\nclass FunctionCallItem(BaseModel):\n    type: Literal[\"function_call\"]\n    name: str\n    arguments: str\n    status: Literal[\"in_progress\", \"completed\", \"incomplete\"] = \"completed\"\n    id: str = \"fc_1234\"\n    call_id: str = \"call_1234\"\n\n\nclass FunctionCallOutputItem(BaseModel):\n    type: Literal[\"function_call_output\"]\n    call_id: str = \"call_1234\"\n    output: str\n\n\nclass WebSearchActionSearch(BaseModel):\n    type: Literal[\"search\"]\n    query: Optional[str] = None\n\n\nclass 
WebSearchActionOpenPage(BaseModel):\n    type: Literal[\"open_page\"]\n    url: Optional[str] = None\n\n\nclass WebSearchActionFind(BaseModel):\n    type: Literal[\"find\"]\n    pattern: Optional[str] = None\n    url: Optional[str] = None\n\n\nclass WebSearchCallItem(BaseModel):\n    type: Literal[\"web_search_call\"]\n    id: str = \"ws_1234\"\n    status: Literal[\"in_progress\", \"completed\", \"incomplete\"] = \"completed\"\n    action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind]\n\n\nclass CodeInterpreterOutputLogs(BaseModel):\n    type: Literal[\"logs\"]\n    logs: str\n\n\nclass CodeInterpreterOutputImage(BaseModel):\n    type: Literal[\"image\"]\n    url: str\n\n\nclass CodeInterpreterCallItem(BaseModel):\n    type: Literal[\"code_interpreter_call\"]\n    id: str = \"ci_1234\"\n    status: Literal[\n        \"in_progress\",\n        \"completed\",\n        \"incomplete\",\n        \"interpreting\",\n        \"failed\",\n    ] = \"completed\"\n    code: Optional[str] = None\n    container_id: Optional[str] = None\n    outputs: Optional[\n        list[Union[CodeInterpreterOutputLogs, CodeInterpreterOutputImage]]\n    ] = None\n\n\nclass Error(BaseModel):\n    code: str\n    message: str\n\n\nclass IncompleteDetails(BaseModel):\n    reason: str\n\n\nclass Usage(BaseModel):\n    input_tokens: int\n    output_tokens: int\n    total_tokens: int\n\n\nclass FunctionToolDefinition(BaseModel):\n    type: Literal[\"function\"]\n    name: str\n    parameters: dict  # this should be typed stricter if you add strict mode\n    strict: bool = False  # change this if you support strict mode\n    description: Optional[str] = \"\"\n\n\nclass BrowserToolConfig(BaseModel):\n    model_config = ConfigDict(extra='allow')\n    type: Literal[\"browser_search\"] | Literal[\"web_search\"]\n\n\nclass CodeInterpreterToolConfig(BaseModel):\n    type: Literal[\"code_interpreter\"]\n\n\nclass ReasoningConfig(BaseModel):\n    effort: Literal[\"low\", 
\"medium\", \"high\"] = REASONING_EFFORT\n\n\nclass ResponsesRequest(BaseModel):\n    instructions: Optional[str] = None\n    max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS\n    input: Union[\n        str,\n        list[\n            Union[\n                Item,\n                ReasoningItem,\n                FunctionCallItem,\n                FunctionCallOutputItem,\n                WebSearchCallItem,\n                CodeInterpreterCallItem,\n            ]\n        ],\n    ]\n    model: Optional[str] = MODEL_IDENTIFIER\n    stream: Optional[bool] = False\n    tools: Optional[\n        list[\n            Union[FunctionToolDefinition, BrowserToolConfig, CodeInterpreterToolConfig]\n        ]\n    ] = []\n    reasoning: Optional[ReasoningConfig] = ReasoningConfig()\n    metadata: Optional[Dict[str, Any]] = {}\n    tool_choice: Optional[Literal[\"auto\", \"none\"]] = \"auto\"\n    parallel_tool_calls: Optional[bool] = False\n    store: Optional[bool] = False\n    previous_response_id: Optional[str] = None\n    temperature: Optional[float] = DEFAULT_TEMPERATURE\n    include: Optional[list[str]] = None\n\n\nclass ResponseObject(BaseModel):\n    output: list[\n        Union[\n            Item,\n            ReasoningItem,\n            FunctionCallItem,\n            FunctionCallOutputItem,\n            WebSearchCallItem,\n            CodeInterpreterCallItem,\n        ]\n    ]\n    created_at: int\n    usage: Optional[Usage] = None\n    status: Literal[\"completed\", \"failed\", \"incomplete\", \"in_progress\"] = \"in_progress\"\n    background: None = None\n    error: Optional[Error] = None\n    incomplete_details: Optional[IncompleteDetails] = None\n    instructions: Optional[str] = None\n    max_output_tokens: Optional[int] = None\n    max_tool_calls: Optional[int] = None\n    metadata: Optional[Dict[str, Any]] = {}\n    model: Optional[str] = MODEL_IDENTIFIER\n    parallel_tool_calls: Optional[bool] = False\n    previous_response_id: Optional[str] = 
None\n    id: Optional[str] = \"resp_1234\"\n    object: Optional[str] = \"response\"\n    text: Optional[Dict[str, Any]] = None\n    tool_choice: Optional[str] = \"auto\"\n    top_p: Optional[float] = 1.0\n"
  },
  {
    "path": "gpt_oss/responses_api/utils.py",
    "content": "import time\n\nfake_tokens = [\n    200005,\n    35644,\n    200008,\n    23483,\n    316,\n    1199,\n    1114,\n    717,\n    170154,\n    13,\n    200007,\n    200006,\n    173781,\n    200005,\n    35644,\n    316,\n    28,\n    44580,\n    775,\n    170154,\n    464,\n    91,\n    542,\n    141043,\n    91,\n    29,\n    4108,\n    200008,\n    10848,\n    7693,\n    7534,\n    28499,\n    18826,\n    18583,\n    200012,\n]\nfake_tokens = [\n    200005,\n    35644,\n    200008,\n    1844,\n    31064,\n    25,\n    392,\n    4827,\n    382,\n    220,\n    17,\n    659,\n    220,\n    17,\n    16842,\n    12295,\n    81645,\n    13,\n    51441,\n    6052,\n    13,\n    200007,\n    200006,\n    173781,\n    200005,\n    17196,\n    200008,\n    17,\n    659,\n    220,\n    17,\n    314,\n    220,\n    19,\n    13,\n    9552,\n    238,\n    242,\n    200002,\n]\n# fake_tokens = [200005, 35644, 200008, 976, 1825, 31064, 25, 392, 25216, 29400, 290, 11122, 306, 52768, 2117, 16842, 1416, 1309, 316, 2281, 198, 68, 290, 2208, 11122, 13, 1416, 679, 261, 1114, 717, 170154, 484, 44390, 261, 5100, 1621, 26, 581, 1757, 2005, 198, 75, 480, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 7801, 4733, 290, 11122, 5377, 484, 290, 1114, 7377, 13, 1416, 1309, 260, 198, 78, 1199, 290, 1114, 4584, 364, 58369, 2421, 717, 170154, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 200007, 200006, 173781, 200005, 12606, 815, 260, 198, 78, 28, 117673, 3490]\n# fake_tokens = [\n#     198,\n#     200005,\n#     35644,\n#     200008,\n#     23483,\n#     316,\n#     1199,\n#     1114,\n#     717,\n#     170154,\n#     13,\n#     200007,\n#     200006,\n#     173781,\n#     200005,\n#     12606,\n#     815,\n#     316,\n#     32455,\n#     106847,\n#     316,\n#     28,\n#     44580,\n#     775,\n#     170154,\n#     464,\n#     91,\n#     542,\n#     141043,\n#     91,\n#     29,\n#     4108,\n#     200008,\n#     10848,\n#     7693,\n#     7534,\n#     28499,\n#     18826,\n#    
 18583,\n#     200012,\n#     198,\n# ]\n\ntoken_queue = fake_tokens.copy()\n\n\ndef stub_infer_next_token(tokens: list[int], temperature: float = 0.0) -> int:\n    global token_queue\n    next_tok = token_queue.pop(0)\n    if len(token_queue) == 0:\n        token_queue = fake_tokens.copy()\n    time.sleep(0.1)\n    return next_tok\n"
  },
  {
    "path": "gpt_oss/tokenizer.py",
    "content": "import tiktoken\n\ndef get_tokenizer():\n    o200k_base = tiktoken.get_encoding(\"o200k_base\")\n    tokenizer = tiktoken.Encoding(\n        name=\"o200k_harmony\",\n        pat_str=o200k_base._pat_str,\n        mergeable_ranks=o200k_base._mergeable_ranks,\n        special_tokens={\n            **o200k_base._special_tokens,\n            \"<|startoftext|>\": 199998,\n            \"<|endoftext|>\": 199999,\n            \"<|reserved_200000|>\": 200000,\n            \"<|reserved_200001|>\": 200001,\n            \"<|return|>\": 200002,\n            \"<|constrain|>\": 200003,\n            \"<|reserved_200004|>\": 200004,\n            \"<|channel|>\": 200005,\n            \"<|start|>\": 200006,\n            \"<|end|>\": 200007,\n            \"<|message|>\": 200008,\n            \"<|reserved_200009|>\": 200009,\n            \"<|reserved_200010|>\": 200010,\n            \"<|reserved_200011|>\": 200011,\n            \"<|call|>\": 200012,\n        } | {\n            f\"<|reserved_{i}|>\": i for i in range(200013, 201088)\n        },\n    )\n    return tokenizer\n"
  },
  {
    "path": "gpt_oss/tools/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/tools/apply_patch.md",
    "content": "When requested to perform coding-related tasks, you MUST adhere to the following criteria when executing the task:\n\n- Use `apply_patch` to edit files.\n- If completing the user's task requires writing or modifying files:\n  - Your code and final answer should follow these _CODING GUIDELINES_:\n    - Avoid unneeded complexity in your solution. Minimize program size.\n    - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.\n    - NEVER add copyright or license headers unless specifically requested.\n- Never implement function stubs. Provide complete working implementations.\n\n§ `apply_patch` Specification\n\nYour patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:\n\n*** Begin Patch\n[ one or more file sections ]\n*** End Patch\n\nWithin that envelope, you get a sequence of file operations.\nYou MUST include a header to specify the action you are taking.\nEach operation starts with one of three headers:\n\n*** Add File: <path> - create a new file. Every following line is a + line (the initial contents).\n*** Delete File: <path> - remove an existing file. 
Nothing follows.\n*** Update File: <path> - patch an existing file in place (optionally with a rename).\n\nMay be immediately followed by *** Move to: <new path> if you want to rename the file.\nThen one or more “hunks”, each introduced by @@ (optionally followed by a hunk header).\nWithin a hunk each line starts with:\n\n- `+` for inserted text,\n- `-` for removed text, or\n- a space (` `) for context.\n\nAt the end of a truncated hunk you can emit *** End of File.\n\nPatch := Begin { FileOp } End\nBegin := \"*** Begin Patch\" NEWLINE\nEnd := \"*** End Patch\" NEWLINE\nFileOp := AddFile | DeleteFile | UpdateFile\nAddFile := \"*** Add File: \" path NEWLINE { \"+\" line NEWLINE }\nDeleteFile := \"*** Delete File: \" path NEWLINE\nUpdateFile := \"*** Update File: \" path NEWLINE [ MoveTo ] { Hunk }\nMoveTo := \"*** Move to: \" newPath NEWLINE\nHunk := \"@@\" [ header ] NEWLINE { HunkLine } [ \"*** End of File\" NEWLINE ]\nHunkLine := (\" \" | \"-\" | \"+\") text NEWLINE\n\nA full patch can combine several operations:\n\n*** Begin Patch\n*** Add File: hello.txt\n+Hello world\n*** Update File: src/app.py\n*** Move to: src/main.py\n@@ def greet():\n-print(\"Hi\")\n+print(\"Hello, world!\")\n*** Delete File: obsolete.txt\n*** End Patch\n\nIt is important to remember:\n\n- You must include a header with your intended action (Add/Delete/Update)\n- You must prefix new lines with `+` even when creating a new file\n"
  },
  {
    "path": "gpt_oss/tools/apply_patch.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\nA self-contained **pure-Python 3.9+** utility for applying human-readable\n“pseudo-diff” patch files to a collection of text files.\n\nSource: https://cookbook.openai.com/examples/gpt4-1_prompting_guide\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pathlib\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom typing import (\n    Callable,\n    Dict,\n    List,\n    Optional,\n    Tuple,\n    Union,\n)\n\n\n# --------------------------------------------------------------------------- #\n#  Domain objects\n# --------------------------------------------------------------------------- #\nclass ActionType(str, Enum):\n    ADD = \"add\"\n    DELETE = \"delete\"\n    UPDATE = \"update\"\n\n\n@dataclass\nclass FileChange:\n    type: ActionType\n    old_content: Optional[str] = None\n    new_content: Optional[str] = None\n    move_path: Optional[str] = None\n\n\n@dataclass\nclass Commit:\n    changes: Dict[str, FileChange] = field(default_factory=dict)\n\n\n# --------------------------------------------------------------------------- #\n#  Exceptions\n# --------------------------------------------------------------------------- #\nclass DiffError(ValueError):\n    \"\"\"Any problem detected while parsing or applying a patch.\"\"\"\n\n\n# --------------------------------------------------------------------------- #\n#  Helper dataclasses used while parsing patches\n# --------------------------------------------------------------------------- #\n@dataclass\nclass Chunk:\n    orig_index: int = -1\n    del_lines: List[str] = field(default_factory=list)\n    ins_lines: List[str] = field(default_factory=list)\n\n\n@dataclass\nclass PatchAction:\n    type: ActionType\n    new_file: Optional[str] = None\n    chunks: List[Chunk] = field(default_factory=list)\n    move_path: Optional[str] = None\n\n\n@dataclass\nclass Patch:\n    actions: Dict[str, PatchAction] = field(default_factory=dict)\n\n\n# 
--------------------------------------------------------------------------- #\n#  Patch text parser\n# --------------------------------------------------------------------------- #\n@dataclass\nclass Parser:\n    current_files: Dict[str, str]\n    lines: List[str]\n    index: int = 0\n    patch: Patch = field(default_factory=Patch)\n    fuzz: int = 0\n\n    # ------------- low-level helpers -------------------------------------- #\n    def _cur_line(self) -> str:\n        if self.index >= len(self.lines):\n            raise DiffError(\"Unexpected end of input while parsing patch\")\n        return self.lines[self.index]\n\n    @staticmethod\n    def _norm(line: str) -> str:\n        \"\"\"Strip CR so comparisons work for both LF and CRLF input.\"\"\"\n        return line.rstrip(\"\\r\")\n\n    # ------------- scanning convenience ----------------------------------- #\n    def is_done(self, prefixes: Optional[Tuple[str, ...]] = None) -> bool:\n        if self.index >= len(self.lines):\n            return True\n        if (\n            prefixes\n            and len(prefixes) > 0\n            and self._norm(self._cur_line()).startswith(prefixes)\n        ):\n            return True\n        return False\n\n    def startswith(self, prefix: Union[str, Tuple[str, ...]]) -> bool:\n        return self._norm(self._cur_line()).startswith(prefix)\n\n    def read_str(self, prefix: str) -> str:\n        \"\"\"\n        Consume the current line if it starts with *prefix* and return the text\n        **after** the prefix.  
Raises if prefix is empty.\n        \"\"\"\n        if prefix == \"\":\n            raise ValueError(\"read_str() requires a non-empty prefix\")\n        if self._norm(self._cur_line()).startswith(prefix):\n            text = self._cur_line()[len(prefix) :]\n            self.index += 1\n            return text\n        return \"\"\n\n    def read_line(self) -> str:\n        \"\"\"Return the current raw line and advance.\"\"\"\n        line = self._cur_line()\n        self.index += 1\n        return line\n\n    # ------------- public entry point -------------------------------------- #\n    def parse(self) -> None:\n        while not self.is_done((\"*** End Patch\",)):\n            # ---------- UPDATE ---------- #\n            path = self.read_str(\"*** Update File: \")\n            if path:\n                if path in self.patch.actions:\n                    raise DiffError(f\"Duplicate update for file: {path}\")\n                move_to = self.read_str(\"*** Move to: \")\n                if path not in self.current_files:\n                    raise DiffError(f\"Update File Error - missing file: {path}\")\n                text = self.current_files[path]\n                action = self._parse_update_file(text)\n                action.move_path = move_to or None\n                self.patch.actions[path] = action\n                continue\n\n            # ---------- DELETE ---------- #\n            path = self.read_str(\"*** Delete File: \")\n            if path:\n                if path in self.patch.actions:\n                    raise DiffError(f\"Duplicate delete for file: {path}\")\n                if path not in self.current_files:\n                    raise DiffError(f\"Delete File Error - missing file: {path}\")\n                self.patch.actions[path] = PatchAction(type=ActionType.DELETE)\n                continue\n\n            # ---------- ADD ---------- #\n            path = self.read_str(\"*** Add File: \")\n            if path:\n                if path in 
self.patch.actions:\n                    raise DiffError(f\"Duplicate add for file: {path}\")\n                if path in self.current_files:\n                    raise DiffError(f\"Add File Error - file already exists: {path}\")\n                self.patch.actions[path] = self._parse_add_file()\n                continue\n\n            raise DiffError(f\"Unknown line while parsing: {self._cur_line()}\")\n\n        if not self.startswith(\"*** End Patch\"):\n            raise DiffError(\"Missing *** End Patch sentinel\")\n        self.index += 1  # consume sentinel\n\n    # ------------- section parsers ---------------------------------------- #\n    def _parse_update_file(self, text: str) -> PatchAction:\n        action = PatchAction(type=ActionType.UPDATE)\n        lines = text.split(\"\\n\")\n        index = 0\n        while not self.is_done(\n            (\n                \"*** End Patch\",\n                \"*** Update File:\",\n                \"*** Delete File:\",\n                \"*** Add File:\",\n                \"*** End of File\",\n            )\n        ):\n            def_str = self.read_str(\"@@ \")\n            section_str = \"\"\n            if not def_str and self._norm(self._cur_line()) == \"@@\":\n                section_str = self.read_line()\n\n            if not (def_str or section_str or index == 0):\n                raise DiffError(f\"Invalid line in update section:\\n{self._cur_line()}\")\n\n            if def_str.strip():\n                found = False\n                if def_str not in lines[:index]:\n                    for i, s in enumerate(lines[index:], index):\n                        if s == def_str:\n                            index = i + 1\n                            found = True\n                            break\n                if not found and def_str.strip() not in [\n                    s.strip() for s in lines[:index]\n                ]:\n                    for i, s in enumerate(lines[index:], index):\n                 
       if s.strip() == def_str.strip():\n                            index = i + 1\n                            self.fuzz += 1\n                            found = True\n                            break\n\n            next_ctx, chunks, end_idx, eof = peek_next_section(self.lines, self.index)\n            new_index, fuzz = find_context(lines, next_ctx, index, eof)\n            if new_index == -1:\n                ctx_txt = \"\\n\".join(next_ctx)\n                raise DiffError(\n                    f\"Invalid {'EOF ' if eof else ''}context at {index}:\\n{ctx_txt}\"\n                )\n            self.fuzz += fuzz\n            for ch in chunks:\n                ch.orig_index += new_index\n                action.chunks.append(ch)\n            index = new_index + len(next_ctx)\n            self.index = end_idx\n        return action\n\n    def _parse_add_file(self) -> PatchAction:\n        lines: List[str] = []\n        while not self.is_done(\n            (\"*** End Patch\", \"*** Update File:\", \"*** Delete File:\", \"*** Add File:\")\n        ):\n            s = self.read_line()\n            if not s.startswith(\"+\"):\n                raise DiffError(f\"Invalid Add File line (missing '+'): {s}\")\n            lines.append(s[1:])  # strip leading '+'\n        return PatchAction(type=ActionType.ADD, new_file=\"\\n\".join(lines))\n\n\n# --------------------------------------------------------------------------- #\n#  Helper functions\n# --------------------------------------------------------------------------- #\ndef find_context_core(\n    lines: List[str], context: List[str], start: int\n) -> Tuple[int, int]:\n    if not context:\n        return start, 0\n\n    for i in range(start, len(lines)):\n        if lines[i : i + len(context)] == context:\n            return i, 0\n    for i in range(start, len(lines)):\n        if [s.rstrip() for s in lines[i : i + len(context)]] == [\n            s.rstrip() for s in context\n        ]:\n            return i, 1\n    for 
i in range(start, len(lines)):\n        if [s.strip() for s in lines[i : i + len(context)]] == [\n            s.strip() for s in context\n        ]:\n            return i, 100\n    return -1, 0\n\n\ndef find_context(\n    lines: List[str], context: List[str], start: int, eof: bool\n) -> Tuple[int, int]:\n    if eof:\n        new_index, fuzz = find_context_core(lines, context, len(lines) - len(context))\n        if new_index != -1:\n            return new_index, fuzz\n        new_index, fuzz = find_context_core(lines, context, start)\n        return new_index, fuzz + 10_000\n    return find_context_core(lines, context, start)\n\n\ndef peek_next_section(\n    lines: List[str], index: int\n) -> Tuple[List[str], List[Chunk], int, bool]:\n    old: List[str] = []\n    del_lines: List[str] = []\n    ins_lines: List[str] = []\n    chunks: List[Chunk] = []\n    mode = \"keep\"\n    orig_index = index\n\n    while index < len(lines):\n        s = lines[index]\n        if s.startswith(\n            (\n                \"@@\",\n                \"*** End Patch\",\n                \"*** Update File:\",\n                \"*** Delete File:\",\n                \"*** Add File:\",\n                \"*** End of File\",\n            )\n        ):\n            break\n        if s == \"***\":\n            break\n        if s.startswith(\"***\"):\n            raise DiffError(f\"Invalid Line: {s}\")\n        index += 1\n\n        last_mode = mode\n        if s == \"\":\n            s = \" \"\n        if s[0] == \"+\":\n            mode = \"add\"\n        elif s[0] == \"-\":\n            mode = \"delete\"\n        elif s[0] == \" \":\n            mode = \"keep\"\n        else:\n            raise DiffError(f\"Invalid Line: {s}\")\n        s = s[1:]\n\n        if mode == \"keep\" and last_mode != mode:\n            if ins_lines or del_lines:\n                chunks.append(\n                    Chunk(\n                        orig_index=len(old) - len(del_lines),\n                        
del_lines=del_lines,\n                        ins_lines=ins_lines,\n                    )\n                )\n            del_lines, ins_lines = [], []\n\n        if mode == \"delete\":\n            del_lines.append(s)\n            old.append(s)\n        elif mode == \"add\":\n            ins_lines.append(s)\n        elif mode == \"keep\":\n            old.append(s)\n\n    if ins_lines or del_lines:\n        chunks.append(\n            Chunk(\n                orig_index=len(old) - len(del_lines),\n                del_lines=del_lines,\n                ins_lines=ins_lines,\n            )\n        )\n\n    if index < len(lines) and lines[index] == \"*** End of File\":\n        index += 1\n        return old, chunks, index, True\n\n    if index == orig_index:\n        raise DiffError(\"Nothing in this section\")\n    return old, chunks, index, False\n\n\n# --------------------------------------------------------------------------- #\n#  Patch → Commit and Commit application\n# --------------------------------------------------------------------------- #\ndef _get_updated_file(text: str, action: PatchAction, path: str) -> str:\n    if action.type is not ActionType.UPDATE:\n        raise DiffError(\"_get_updated_file called with non-update action\")\n    orig_lines = text.split(\"\\n\")\n    dest_lines: List[str] = []\n    orig_index = 0\n\n    for chunk in action.chunks:\n        if chunk.orig_index > len(orig_lines):\n            raise DiffError(\n                f\"{path}: chunk.orig_index {chunk.orig_index} exceeds file length\"\n            )\n        if orig_index > chunk.orig_index:\n            raise DiffError(\n                f\"{path}: overlapping chunks at {orig_index} > {chunk.orig_index}\"\n            )\n\n        dest_lines.extend(orig_lines[orig_index : chunk.orig_index])\n        orig_index = chunk.orig_index\n\n        dest_lines.extend(chunk.ins_lines)\n        orig_index += len(chunk.del_lines)\n\n    dest_lines.extend(orig_lines[orig_index:])\n    
return \"\\n\".join(dest_lines)\n\n\ndef patch_to_commit(patch: Patch, orig: Dict[str, str]) -> Commit:\n    commit = Commit()\n    for path, action in patch.actions.items():\n        if action.type is ActionType.DELETE:\n            commit.changes[path] = FileChange(\n                type=ActionType.DELETE, old_content=orig[path]\n            )\n        elif action.type is ActionType.ADD:\n            if action.new_file is None:\n                raise DiffError(\"ADD action without file content\")\n            commit.changes[path] = FileChange(\n                type=ActionType.ADD, new_content=action.new_file\n            )\n        elif action.type is ActionType.UPDATE:\n            new_content = _get_updated_file(orig[path], action, path)\n            commit.changes[path] = FileChange(\n                type=ActionType.UPDATE,\n                old_content=orig[path],\n                new_content=new_content,\n                move_path=action.move_path,\n            )\n    return commit\n\n\n# --------------------------------------------------------------------------- #\n#  User-facing helpers\n# --------------------------------------------------------------------------- #\ndef text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n    lines = text.splitlines()  # preserves blank lines, no strip()\n    if (\n        len(lines) < 2\n        or not Parser._norm(lines[0]).startswith(\"*** Begin Patch\")\n        or Parser._norm(lines[-1]) != \"*** End Patch\"\n    ):\n        raise DiffError(\"Invalid patch text - missing sentinels\")\n\n    parser = Parser(current_files=orig, lines=lines, index=1)\n    parser.parse()\n    return parser.patch, parser.fuzz\n\n\ndef identify_files_needed(text: str) -> List[str]:\n    lines = text.splitlines()\n    return [\n        line[len(\"*** Update File: \") :]\n        for line in lines\n        if line.startswith(\"*** Update File: \")\n    ] + [\n        line[len(\"*** Delete File: \") :]\n        for line in 
lines\n        if line.startswith(\"*** Delete File: \")\n    ]\n\n\ndef identify_files_added(text: str) -> List[str]:\n    lines = text.splitlines()\n    return [\n        line[len(\"*** Add File: \") :]\n        for line in lines\n        if line.startswith(\"*** Add File: \")\n    ]\n\n\n# --------------------------------------------------------------------------- #\n#  File-system helpers\n# --------------------------------------------------------------------------- #\ndef load_files(paths: List[str], open_fn: Callable[[str], str]) -> Dict[str, str]:\n    return {path: open_fn(path) for path in paths}\n\n\ndef apply_commit(\n    commit: Commit,\n    write_fn: Callable[[str, str], None],\n    remove_fn: Callable[[str], None],\n) -> None:\n    for path, change in commit.changes.items():\n        if change.type is ActionType.DELETE:\n            remove_fn(path)\n        elif change.type is ActionType.ADD:\n            if change.new_content is None:\n                raise DiffError(f\"ADD change for {path} has no content\")\n            write_fn(path, change.new_content)\n        elif change.type is ActionType.UPDATE:\n            if change.new_content is None:\n                raise DiffError(f\"UPDATE change for {path} has no new content\")\n            target = change.move_path or path\n            write_fn(target, change.new_content)\n            if change.move_path:\n                remove_fn(path)\n\n\ndef open_file(path: str) -> str:\n    with open(path, \"rt\", encoding=\"utf-8\") as fh:\n        return fh.read()\n\n\ndef write_file(path: str, content: str) -> None:\n    target = pathlib.Path(path)\n    target.parent.mkdir(parents=True, exist_ok=True)\n    with target.open(\"wt\", encoding=\"utf-8\") as fh:\n        fh.write(content)\n\n\ndef remove_file(path: str) -> None:\n    pathlib.Path(path).unlink(missing_ok=True)\n\n\n\ndef apply_patch(\n    text: str,\n    open_fn: Callable[[str], str] = open_file,\n    write_fn: Callable[[str, str], None] = 
write_file, \n    remove_fn: Callable[[str], None] = remove_file,\n) -> str:\n    if not text.startswith(\"*** Begin Patch\"):\n        raise DiffError(\"Patch text must start with *** Begin Patch\")\n    paths = identify_files_needed(text)\n    orig = load_files(paths, open_fn)\n    patch, _fuzz = text_to_patch(text, orig)\n    commit = patch_to_commit(patch, orig)\n    apply_commit(commit, write_fn, remove_fn)\n    return \"Done!\"\n\n\ndef main() -> None:\n    import sys\n\n    patch_text = sys.stdin.read()\n    if not patch_text:\n        print(\"Please pass patch text through stdin\", file=sys.stderr)\n        return\n    try:\n        result = apply_patch(patch_text)\n    except DiffError as exc:\n        print(exc, file=sys.stderr)\n        return\n    print(result)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "gpt_oss/tools/python_docker/docker_tool.py",
    "content": "# Run this before running the tool:\n# $ docker image pull python:3.11\nimport asyncio\nimport contextlib\nimport io\nimport os\nimport queue\nimport subprocess\nimport tarfile\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, AsyncIterator\n\nimport docker\nfrom openai_harmony import (\n    Author,\n    Content,\n    Message,\n    Role,\n    TextContent,\n    ToolNamespaceConfig,\n)\n\nfrom ..tool import Tool\n\n_docker_client = None\n\nVALID_EXECUTION_BACKENDS = {\n    \"docker\",\n    \"dangerously_use_uv\",\n    \"dangerously_use_local_jupyter\",\n}\n\n_default_backend = os.environ.get(\"PYTHON_EXECUTION_BACKEND\", \"docker\")\nif _default_backend not in VALID_EXECUTION_BACKENDS:\n    _default_backend = \"docker\"\n\nPYTHON_EXECUTION_BACKEND = _default_backend\n\n\ndef call_python_script(script: str) -> str:\n    \"\"\"\n    Call a python script by writing it to a file in the container and executing it.\n    \"\"\"\n    global _docker_client\n    if _docker_client is None:\n        _docker_client = docker.from_env()\n        # pull image `python:3.11` if not present\n        try:\n            _docker_client.images.get(\"python:3.11\")\n        except docker.errors.ImageNotFound:\n            _docker_client.images.pull(\"python:3.11\")\n\n    # 1. Create a temporary tar archive containing the script\n    script_name = \"script.py\"\n    tarstream = io.BytesIO()\n    with tarfile.open(fileobj=tarstream, mode=\"w\") as tar:\n        script_bytes = script.encode(\"utf-8\")\n        tarinfo = tarfile.TarInfo(name=script_name)\n        tarinfo.size = len(script_bytes)\n        tar.addfile(tarinfo, io.BytesIO(script_bytes))\n    tarstream.seek(0)\n\n    # 2. Start the container\n    container = _docker_client.containers.create(\n        \"python:3.11\", command=\"sleep infinity\", detach=True\n    )\n    try:\n        container.start()\n        # 3. 
Put the script into the container\n        container.put_archive(path=\"/tmp\", data=tarstream.read())\n        # 4. Execute the script\n        exec_result = container.exec_run(f\"python /tmp/{script_name}\")\n        output = exec_result.output.decode(\"utf-8\")\n        if not output.strip():\n            output = \"[WARN] No output available. Use print() to output anything to stdout to receive the output\"\n    finally:\n        container.remove(force=True)\n    return output\n\n\ndef call_python_script_with_uv(script: str) -> str:\n    \"\"\"\n    Call a python script by writing it to a file to a temporary directory\n    and executing it with uv.\n    \"\"\"\n    with tempfile.TemporaryDirectory() as temp_dir:\n        script_path = os.path.join(temp_dir, \"script.py\")\n        with open(script_path, \"w\") as f:\n            f.write(script)\n        exec_result = subprocess.run(\n            [\"uv\", \"run\", \"--no-project\", \"python\", script_path],\n            capture_output=True)\n        return (\n            exec_result.stdout.decode(\"utf-8\")\n            if exec_result.returncode == 0\n            else exec_result.stderr.decode(\"utf-8\")\n        )\n\n\nclass LocalJupyterSession:\n    \"\"\"Stateful helper that proxies execution through a local Jupyter kernel.\"\"\"\n\n    def __init__(\n        self,\n        connection_file: str | None = None,\n        *,\n        timeout: float = 120.0,\n    ) -> None:\n        try:\n            from jupyter_client import BlockingKernelClient, KernelManager\n        except ImportError as exc:  # pragma: no cover - optional dependency\n            raise RuntimeError(\n                \"The dangerously_use_local_jupyter backend requires the jupyter_client package to be installed.\"\n            ) from exc\n\n        self._default_timeout = timeout\n        self._owns_kernel = False\n        self._client: BlockingKernelClient\n        self._km: KernelManager | None = None\n\n        if connection_file:\n          
  connection_path = Path(connection_file).expanduser()\n            if not connection_path.exists():\n                raise FileNotFoundError(\n                    f\"Cannot find Jupyter connection file at '{connection_path}'.\"\n                )\n            client = BlockingKernelClient()\n            client.load_connection_file(str(connection_path))\n            client.start_channels()\n            # Ensure the connection is ready before executing.\n            client.wait_for_ready(timeout=self._default_timeout)\n            self._client = client\n        else:\n            km = KernelManager()\n            km.start_kernel()\n            client = km.blocking_client()\n            client.start_channels()\n            client.wait_for_ready(timeout=self._default_timeout)\n            self._client = client\n            self._km = km\n            self._owns_kernel = True\n\n    def execute(self, code: str, *, timeout: float | None = None) -> str:\n        \"\"\"Execute code in the kernel, returning combined stdout/stderr output.\"\"\"\n\n        client = self._client\n        effective_timeout = timeout or self._default_timeout\n        msg_id = client.execute(\n            code,\n            store_history=True,\n            allow_stdin=False,\n            stop_on_error=False,\n        )\n\n        stdout_parts: list[str] = []\n        stderr_parts: list[str] = []\n\n        while True:\n            try:\n                msg = client.get_iopub_msg(timeout=effective_timeout)\n            except queue.Empty as exc:\n                raise TimeoutError(\"Timed out waiting for Jupyter kernel output.\") from exc\n\n            if msg.get(\"parent_header\", {}).get(\"msg_id\") != msg_id:\n                continue\n\n            msg_type = msg.get(\"msg_type\")\n            content = msg.get(\"content\", {})\n\n            if msg_type == \"stream\":\n                text = content.get(\"text\", \"\")\n                if content.get(\"name\") == \"stdout\":\n                
    stdout_parts.append(text)\n                else:\n                    stderr_parts.append(text)\n            elif msg_type == \"error\":\n                traceback_data = content.get(\"traceback\")\n                if traceback_data:\n                    stderr_parts.append(\"\\n\".join(traceback_data))\n                else:\n                    ename = content.get(\"ename\", \"\")\n                    evalue = content.get(\"evalue\", \"\")\n                    stderr_parts.append(f\"{ename}: {evalue}\".strip())\n            elif msg_type in {\"execute_result\", \"display_data\"}:\n                data = content.get(\"data\", {})\n                text = data.get(\"text/plain\")\n                if text:\n                    stdout_parts.append(text if text.endswith(\"\\n\") else f\"{text}\\n\")\n            elif msg_type == \"status\" and content.get(\"execution_state\") == \"idle\":\n                break\n\n        # Drain the shell channel to capture final execution status.\n        while True:\n            try:\n                reply = client.get_shell_msg(timeout=effective_timeout)\n            except queue.Empty as exc:\n                raise TimeoutError(\n                    \"Timed out waiting for Jupyter kernel execution reply.\"\n                ) from exc\n\n            if reply.get(\"parent_header\", {}).get(\"msg_id\") != msg_id:\n                continue\n\n            reply_content = reply.get(\"content\", {})\n            if reply_content.get(\"status\") == \"error\":\n                traceback_data = reply_content.get(\"traceback\")\n                if traceback_data:\n                    stderr_parts.append(\"\\n\".join(traceback_data))\n                else:\n                    ename = reply_content.get(\"ename\", \"\")\n                    evalue = reply_content.get(\"evalue\", \"\")\n                    stderr_parts.append(f\"{ename}: {evalue}\".strip())\n            break\n\n        stdout = \"\".join(stdout_parts)\n        stderr = 
\"\".join(stderr_parts)\n\n        if stderr:\n            if stdout:\n                stdout = f\"{stdout.rstrip()}\\n{stderr}\"\n            else:\n                stdout = stderr\n\n        if not stdout.strip():\n            stdout = (\n                \"[WARN] No output available. Use print() to output anything to stdout to \"\n                \"receive the output\"\n            )\n\n        return stdout\n\n    def close(self) -> None:\n        with contextlib.suppress(Exception):\n            self._client.stop_channels()\n\n        if self._owns_kernel and self._km is not None:\n            with contextlib.suppress(Exception):\n                self._km.shutdown_kernel(now=True)\n\n    def __del__(self) -> None:  # pragma: no cover - best-effort cleanup\n        self.close()\n\nclass PythonTool(Tool):\n    def __init__(\n        self,\n        name: str = \"python\",\n        *,\n        execution_backend: str | None = None,\n        local_jupyter_connection_file: str | None = None,\n        local_jupyter_timeout: float = 60.0,\n    ):\n        assert name == \"python\"\n\n        backend = execution_backend or PYTHON_EXECUTION_BACKEND\n        if backend not in VALID_EXECUTION_BACKENDS:\n            raise ValueError(\n                \"execution_backend must be one of: \"\n                + \", \".join(sorted(VALID_EXECUTION_BACKENDS))\n            )\n\n        self._execution_backend = backend\n        self._local_jupyter_connection_file = (\n            local_jupyter_connection_file\n            or os.environ.get(\"PYTHON_LOCAL_JUPYTER_CONNECTION_FILE\")\n        )\n        self._local_jupyter_timeout = local_jupyter_timeout\n\n        self._jupyter_session: LocalJupyterSession | None = None\n        self._execution_lock: asyncio.Lock | None = None\n\n        if self._execution_backend == \"dangerously_use_local_jupyter\":\n            self._execution_lock = asyncio.Lock()\n            self._jupyter_session = LocalJupyterSession(\n                
connection_file=self._local_jupyter_connection_file,\n                timeout=self._local_jupyter_timeout,\n            )\n\n    @classmethod\n    def get_tool_name(cls) -> str:\n        return \"python\"\n\n    @property\n    def name(self) -> str:\n        return self.get_tool_name()\n\n    @property\n    def instruction(self) -> str:\n        if self._execution_backend == \"dangerously_use_local_jupyter\":\n            # Report the timeout actually configured for the Jupyter session\n            # instead of a hard-coded value that can disagree with it.\n            return f\"\"\"\nUse this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after {self._local_jupyter_timeout} seconds. Internet access for this session is UNKNOWN. Depends on the cluster.\n            \"\"\".strip()\n\n        return \"\"\"\nUse this tool to execute STATELESS Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\nWhen you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. You have to use print statements to access the output.\n\nIMPORTANT: Your python environment is not shared between calls. 
You will have to pass your entire code each time.\n        \"\"\".strip()\n\n    @property\n    def tool_config(self) -> ToolNamespaceConfig:\n        return ToolNamespaceConfig(\n            name=self.get_tool_name(), description=self.instruction, tools=[]\n        )\n\n    def _make_response(\n        self,\n        output: str,\n        channel: str | None = None,\n    ) -> Message:\n        content = TextContent(text=output)\n        return self.make_response(content=content, channel=channel)\n\n    def make_response(\n        self,\n        content: Content,\n        *,\n        metadata: dict[str, Any] | None = None,\n        author: Author | None = None,\n        channel: str | None = None,\n    ) -> Message:\n        tool_name = self.get_tool_name()\n        author = Author(role=Role.TOOL, name=f\"{tool_name}\")\n\n        message = Message(\n            author=author,\n            content=[content],\n        ).with_recipient(\"assistant\")\n\n        if channel:\n            message = message.with_channel(channel)\n\n        return message\n\n    async def _process(self, message: Message) -> AsyncIterator[Message]:\n        script = message.content[0].text\n        channel = message.channel\n\n        if self._execution_backend == \"docker\":\n            output = call_python_script(script)\n        elif self._execution_backend == \"dangerously_use_uv\":\n            output = call_python_script_with_uv(script)\n        elif self._execution_backend == \"dangerously_use_local_jupyter\":\n            assert self._jupyter_session is not None\n            lock = self._execution_lock\n            if lock is not None:\n                async with lock:\n                    try:\n                        output = self._jupyter_session.execute(script)\n                    except TimeoutError as exc:\n                        output = f\"[ERROR] {exc}\"\n            else:\n                try:\n                    output = self._jupyter_session.execute(script)\n        
        except TimeoutError as exc:\n                    output = f\"[ERROR] {exc}\"\n        else:\n            raise ValueError(\n                f\"Invalid PYTHON_EXECUTION_BACKEND: {self._execution_backend}\"\n            )\n        yield self._make_response(output, channel=channel)\n\n    def close(self) -> None:\n        if self._jupyter_session is not None:\n            self._jupyter_session.close()\n\n    def __del__(self) -> None:  # pragma: no cover - best-effort cleanup\n        self.close()\n"
  },
  {
    "path": "gpt_oss/tools/simple_browser/__init__.py",
    "content": "from .simple_browser_tool import SimpleBrowserTool\nfrom .backend import ExaBackend, YouComBackend\n\n__all__ = [\n    \"SimpleBrowserTool\",\n    \"ExaBackend\",\n    \"YouComBackend\",\n]\n"
  },
  {
    "path": "gpt_oss/tools/simple_browser/backend.py",
    "content": "\"\"\"\nSimple backend for the simple browser tool.\n\"\"\"\n\nimport functools\nimport asyncio\nimport logging\nimport os\nfrom abc import abstractmethod\nfrom importlib.metadata import version\nfrom typing import Callable, ParamSpec, TypeVar\nfrom urllib.parse import quote\n\nimport chz\nfrom aiohttp import ClientSession, ClientTimeout\nfrom tenacity import (\n    after_log,\n    before_sleep_log,\n    retry,\n    retry_if_exception_type,\n    stop_after_attempt,\n    wait_exponential,\n)\n\nfrom .page_contents import (\n    Extract,\n    FetchResult,\n    PageContents,\n    get_domain,\n    process_html,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nVIEW_SOURCE_PREFIX = \"view-source:\"\n\ntry:\n    _GPT_OSS_VERSION = version(\"gpt-oss\")\nexcept Exception:\n    _GPT_OSS_VERSION = \"0.0.8\"  # fallback version\n\n\nclass BackendError(Exception):\n    pass\n\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n\n\ndef with_retries(\n    func: Callable[P, R],\n    num_retries: int,\n    max_wait_time: float,\n) -> Callable[P, R]:\n    if num_retries > 0:\n        retry_decorator = retry(\n            stop=stop_after_attempt(num_retries),\n            wait=wait_exponential(\n                multiplier=1,\n                min=2,\n                max=max_wait_time,\n            ),\n            before_sleep=before_sleep_log(logger, logging.INFO),\n            after=after_log(logger, logging.INFO),\n            retry=retry_if_exception_type(Exception),\n        )\n        return retry_decorator(func)\n    else:\n        return func\n\n\ndef maybe_truncate(text: str, num_chars: int = 1024) -> str:\n    if len(text) > num_chars:\n        text = text[: (num_chars - 3)] + \"...\"\n    return text\n\n\n@chz.chz(typecheck=True)\nclass Backend:\n    source: str = chz.field(doc=\"Description of the backend source\")\n\n    @abstractmethod\n    async def search(\n        self,\n        query: str,\n        topn: int,\n        session: ClientSession,\n    ) -> 
PageContents:\n        pass\n\n    @abstractmethod\n    async def fetch(self, url: str, session: ClientSession) -> PageContents:\n        pass\n\n    async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:\n        headers = {\n            \"x-api-key\": self._get_api_key(),\n            \"user-agent\": f\"gpt-oss/{_GPT_OSS_VERSION}\",\n        }\n        async with session.post(f\"{self.BASE_URL}{endpoint}\", json=payload, headers=headers) as resp:\n            if resp.status != 200:\n                raise BackendError(\n                    f\"{self.__class__.__name__} error {resp.status}: {await resp.text()}\"\n                )\n            return await resp.json()\n\n    async def _get(self, session: ClientSession, endpoint: str, params: dict) -> dict:\n        headers = {\n            \"x-api-key\": self._get_api_key(),\n            \"user-agent\": f\"gpt-oss/{_GPT_OSS_VERSION}\",\n        }\n        async with session.get(f\"{self.BASE_URL}{endpoint}\", params=params, headers=headers) as resp:\n            if resp.status != 200:\n                raise BackendError(\n                    f\"{self.__class__.__name__} error {resp.status}: {await resp.text()}\"\n                )\n            return await resp.json()\n\n\n@chz.chz(typecheck=True)\nclass ExaBackend(Backend):\n    \"\"\"Backend that uses the Exa Search API.\"\"\"\n\n    source: str = chz.field(doc=\"Description of the backend source\")\n    api_key: str | None = chz.field(\n        doc=\"Exa API key. 
Uses EXA_API_KEY environment variable if not provided.\",\n        default=None,\n    )\n\n    BASE_URL: str = \"https://api.exa.ai\"\n\n    def _get_api_key(self) -> str:\n        key = self.api_key or os.environ.get(\"EXA_API_KEY\")\n        if not key:\n            raise BackendError(\"Exa API key not provided\")\n        return key\n\n\n    async def search(\n        self, query: str, topn: int, session: ClientSession\n    ) -> PageContents:\n        data = await self._post(\n            session,\n            \"/search\",\n            {\"query\": query, \"numResults\": topn, \"contents\": {\"text\": True, \"summary\": True}},\n        )\n        # make a simple HTML page to work with browser format\n        titles_and_urls = [\n            (result[\"title\"], result[\"url\"], result[\"summary\"])\n            for result in data[\"results\"]\n        ]\n        html_page = f\"\"\"\n<html><body>\n<h1>Search Results</h1>\n<ul>\n{\"\".join([f\"<li><a href='{url}'>{title}</a> {summary}</li>\" for title, url, summary in titles_and_urls])}\n</ul>\n</body></html>\n\"\"\"\n\n        return process_html(\n            html=html_page,\n            url=\"\",\n            title=query,\n            display_urls=True,\n            session=session,\n        )\n\n    async def fetch(self, url: str, session: ClientSession) -> PageContents:\n        is_view_source = url.startswith(VIEW_SOURCE_PREFIX)\n        if is_view_source:\n            url = url[len(VIEW_SOURCE_PREFIX) :]\n        data = await self._post(\n            session,\n            \"/contents\",\n            {\"urls\": [url], \"text\": { \"includeHtmlTags\": True }},\n        )\n        results = data.get(\"results\", [])\n        if not results:\n            raise BackendError(f\"No contents returned for {url}\")\n        return process_html(\n            html=results[0].get(\"text\", \"\"),\n            url=url,\n            title=results[0].get(\"title\", \"\"),\n            display_urls=True,\n            
session=session,\n        )\n\n@chz.chz(typecheck=True)\nclass YouComBackend(Backend):\n    \"\"\"Backend that uses the You.com Search API.\"\"\"\n\n    source: str = chz.field(doc=\"Description of the backend source\")\n\n    BASE_URL: str = \"https://api.ydc-index.io\"\n\n    def _get_api_key(self) -> str:\n        key = os.environ.get(\"YDC_API_KEY\")\n        if not key:\n            raise BackendError(\"You.com API key not provided\")\n        return key\n\n    \n    async def search(\n        self, query: str, topn: int, session: ClientSession\n    ) -> PageContents:\n        data = await self._get(\n            session,\n            \"/v1/search\",\n            {\"query\": query, \"count\": topn},\n        )\n        # make a simple HTML page to work with browser format\n        web_titles_and_urls, news_titles_and_urls = [], []\n        if \"web\" in data[\"results\"]:\n            web_titles_and_urls = [\n                (result[\"title\"], result[\"url\"], result[\"snippets\"])\n                for result in data[\"results\"][\"web\"]\n            ]\n        if \"news\" in data[\"results\"]:\n            news_titles_and_urls = [\n                (result[\"title\"], result[\"url\"], result[\"description\"])\n                for result in data[\"results\"][\"news\"]\n            ]\n        titles_and_urls = web_titles_and_urls + news_titles_and_urls\n        html_page = f\"\"\"\n<html><body>\n<h1>Search Results</h1>\n<ul>\n{\"\".join([f\"<li><a href='{url}'>{title}</a> {summary}</li>\" for title, url, summary in titles_and_urls])}\n</ul>\n</body></html>\n\"\"\"\n\n        return process_html(\n            html=html_page,\n            url=\"\",\n            title=query,\n            display_urls=True,\n            session=session,\n        )\n\n    async def fetch(self, url: str, session: ClientSession) -> PageContents:\n        is_view_source = url.startswith(VIEW_SOURCE_PREFIX)\n        if is_view_source:\n            url = url[len(VIEW_SOURCE_PREFIX) :]\n 
       data = await self._post(\n            session,\n            \"/v1/contents\",\n            {\"urls\": [url], \"livecrawl_formats\": \"html\"},\n        )\n        if not data:\n            raise BackendError(f\"No contents returned for {url}\")\n        if \"html\" not in data[0]:\n            raise BackendError(f\"No HTML returned for {url}\")\n        return process_html(\n            html=data[0].get(\"html\", \"\"),\n            url=url,\n            title=data[0].get(\"title\", \"\"),\n            display_urls=True,\n            session=session,\n        )\n\n"
  },
  {
    "path": "gpt_oss/tools/simple_browser/page_contents.py",
    "content": "\"\"\"\nPage contents for the simple browser tool.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nimport functools\nimport logging\nimport re\nfrom urllib.parse import urljoin, urlparse\n\nimport aiohttp\nimport html2text\nimport lxml\nimport lxml.etree\nimport lxml.html\nimport pydantic\n\nimport tiktoken\n\nlogger = logging.getLogger(__name__)\n\n\nHTML_SUP_RE = re.compile(r\"<sup( [^>]*)?>([\\w\\-]+)</sup>\")\nHTML_SUB_RE = re.compile(r\"<sub( [^>]*)?>([\\w\\-]+)</sub>\")\nHTML_TAGS_SEQ_RE = re.compile(r\"(?<=\\w)((<[^>]*>)+)(?=\\w)\")\nWHITESPACE_ANCHOR_RE = re.compile(r\"(【\\@[^】]+】)(\\s+)\")\nEMPTY_LINE_RE = re.compile(r\"^\\s+$\", flags=re.MULTILINE)\nEXTRA_NEWLINE_RE = re.compile(r\"\\n(\\s*\\n)+\")\n\n\nclass Extract(pydantic.BaseModel):  # A search result snippet or a quotable extract\n    url: str\n    text: str\n    title: str\n    line_idx: int | None = None\n\n\nclass FetchResult(pydantic.BaseModel):\n    url: str\n    success: bool\n    title: str | None = None\n    error_type: str | None = None\n    error_message: str | None = None\n    html: str | None = None\n    raw_content: bytes | None = None\n    plaintext: str | None = None\n\n\nclass PageContents(pydantic.BaseModel):\n    url: str\n    text: str\n    title: str\n    urls: dict[str, str]\n    snippets: dict[str, Extract] | None = None\n    error_message: str | None = None\n\n\n@dataclasses.dataclass(frozen=True)\nclass Tokens:\n    tokens: list[int]\n    tok2idx: list[int]  # Offsets = running sum of lengths.\n\n\ndef get_domain(url: str) -> str:\n    \"\"\"Extracts the domain from a URL.\"\"\"\n    if \"http\" not in url:\n        # If `get_domain` is called on a domain, add a scheme so that the\n        # original domain is returned instead of the empty string.\n        url = \"http://\" + url\n    return urlparse(url).netloc\n\n\ndef multiple_replace(text: str, replacements: dict[str, str]) -> str:\n    \"\"\"Performs multiple string replacements using 
a single regex pass.\"\"\"\n    regex = re.compile(\"(%s)\" % \"|\".join(map(re.escape, replacements.keys())))\n    return regex.sub(lambda mo: replacements[mo.group(1)], text)\n\n\n@functools.lru_cache(maxsize=1024)\ndef mark_lines(text: str) -> str:\n    \"\"\"Adds line numbers (ex: 'L0:') to the beginning of each line in a string.\"\"\"\n    # Split the string by newline characters\n    lines = text.split(\"\\n\")\n\n    # Add line numbers to each line and join into a single string\n    numbered_text = \"\\n\".join([f\"L{i}: {line}\" for i, line in enumerate(lines)])\n    return numbered_text\n\n\n@functools.cache\ndef _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]:\n    \"\"\"Gets the character lengths of all tokens in the specified TikToken vocabulary.\"\"\"\n    encoding = tiktoken.get_encoding(enc_name)\n    return [len(encoding.decode([i])) for i in range(encoding.n_vocab)]\n\n\ndef warmup_caches(enc_names: list[str]) -> None:\n    \"\"\"Warm up the cache by computing token length lists for the given TikToken encodings.\"\"\"\n    for _ in map(_tiktoken_vocabulary_lengths, enc_names):\n        pass\n\n\ndef _replace_special_chars(text: str) -> str:\n    \"\"\"Replaces specific special characters with visually similar alternatives.\"\"\"\n    replacements = {\n        \"【\": \"〖\",\n        \"】\": \"〗\",\n        \"◼\": \"◾\",\n        # \"━\": \"─\",\n        \"\\u200b\": \"\",  # zero width space\n        # Note: not replacing †\n    }\n    return multiple_replace(text, replacements)\n\n\ndef merge_whitespace(text: str) -> str:\n    \"\"\"Replace newlines with spaces and merge consecutive whitespace into a single space.\"\"\"\n    text = text.replace(\"\\n\", \" \")\n    text = re.sub(r\"\\s+\", \" \", text)\n    return text\n\n\ndef arxiv_to_ar5iv(url: str) -> str:\n    \"\"\"Converts an arxiv.org URL to its ar5iv.org equivalent.\"\"\"\n    # Escape the dot so it matches a literal '.' rather than any character.\n    return re.sub(r\"arxiv\\.org\", r\"ar5iv.org\", url)\n\n\ndef _clean_links(root: lxml.html.HtmlElement, cur_url: 
str) -> dict[str, str]:\n    \"\"\"Processes all anchor tags in the HTML, replaces them with a custom format and returns an ID-to-URL mapping.\"\"\"\n    cur_domain = get_domain(cur_url)\n    urls: dict[str, str] = {}\n    urls_rev: dict[str, str] = {}\n    for a in root.findall(\".//a[@href]\"):\n        assert a.getparent() is not None\n        link = a.attrib[\"href\"]\n        if link.startswith((\"mailto:\", \"javascript:\")):\n            continue\n        text = _get_text(a).replace(\"†\", \"‡\")\n        if not re.sub(r\"【\\@([^】]+)】\", \"\", text):  # Probably an image\n            continue\n        if link.startswith(\"#\"):\n            replace_node_with_text(a, text)\n            continue\n        try:\n            link = urljoin(cur_url, link)  # works with both absolute and relative links\n            domain = get_domain(link)\n        except Exception:\n            domain = \"\"\n        if not domain:\n            logger.debug(\"SKIPPING LINK WITH URL %s\", link)\n            continue\n        link = arxiv_to_ar5iv(link)\n        if (link_id := urls_rev.get(link)) is None:\n            link_id = f\"{len(urls)}\"\n            urls[link_id] = link\n            urls_rev[link] = link_id\n        if domain == cur_domain:\n            replacement = f\"【{link_id}†{text}】\"\n        else:\n            replacement = f\"【{link_id}†{text}†{domain}】\"\n        replace_node_with_text(a, replacement)\n    return urls\n\n\ndef _get_text(node: lxml.html.HtmlElement) -> str:\n    \"\"\"Extracts all text from an HTML element and merges it into a whitespace-normalized string.\"\"\"\n    return merge_whitespace(\" \".join(node.itertext()))\n\n\ndef _remove_node(node: lxml.html.HtmlElement) -> None:\n    \"\"\"Removes a node from its parent in the lxml tree.\"\"\"\n    node.getparent().remove(node)\n\n\ndef _escape_md(text: str) -> str:\n    return text\n\n\ndef _escape_md_section(text: str, snob: bool = False) -> str:\n    return text\n\n\ndef html_to_text(html: str) 
-> str:\n    \"\"\"Converts an HTML string to clean plaintext.\"\"\"\n    html = re.sub(HTML_SUP_RE, r\"^{\\2}\", html)\n    html = re.sub(HTML_SUB_RE, r\"_{\\2}\", html)\n    # add spaces between tags such as table cells\n    html = re.sub(HTML_TAGS_SEQ_RE, r\" \\1\", html)\n    # we don't need to escape markdown, so temporarily monkey-patch the escaping logic\n    orig_escape_md = html2text.utils.escape_md\n    orig_escape_md_section = html2text.utils.escape_md_section\n    html2text.utils.escape_md = _escape_md\n    html2text.utils.escape_md_section = _escape_md_section\n    try:\n        h = html2text.HTML2Text()\n        h.ignore_links = True\n        h.ignore_images = True\n        h.body_width = 0  # no wrapping\n        h.ignore_tables = True\n        h.unicode_snob = True\n        h.ignore_emphasis = True\n        result = h.handle(html).strip()\n    finally:\n        # restore the original escaping functions even if conversion raises\n        html2text.utils.escape_md = orig_escape_md\n        html2text.utils.escape_md_section = orig_escape_md_section\n    return result\n\n\ndef _remove_math(root: lxml.html.HtmlElement) -> None:\n    \"\"\"Removes all <math> elements from the lxml tree.\"\"\"\n    for node in root.findall(\".//math\"):\n        _remove_node(node)\n\n\ndef remove_unicode_smp(text: str) -> str:\n    \"\"\"Removes Unicode characters in the Supplemental Multilingual Plane (SMP) from `text`.\n\n    SMP characters are not supported by lxml.html processing.\n    \"\"\"\n    smp_pattern = re.compile(r\"[\\U00010000-\\U0001FFFF]\", re.UNICODE)\n    return smp_pattern.sub(\"\", text)\n\n\ndef replace_node_with_text(node: lxml.html.HtmlElement, text: str) -> None:\n    \"\"\"Replaces an lxml node with a text string while preserving surrounding text.\"\"\"\n    previous = node.getprevious()\n    parent = node.getparent()\n    tail = node.tail or \"\"\n    if previous is None:\n        parent.text = (parent.text or \"\") + text + tail\n    else:\n        previous.tail = (previous.tail or \"\") + text + tail\n    parent.remove(node)\n\n\ndef replace_images(\n    root: lxml.html.HtmlElement,\n    base_url: 
str,\n    session: aiohttp.ClientSession | None,\n) -> None:\n    \"\"\"Finds all image tags and replaces them with numbered placeholders (includes alt/title if available).\"\"\"\n    cnt = 0\n    for img_tag in root.findall(\".//img\"):\n        image_name = img_tag.get(\"alt\", img_tag.get(\"title\"))\n        if image_name:\n            replacement = f\"[Image {cnt}: {image_name}]\"\n        else:\n            replacement = f\"[Image {cnt}]\"\n        replace_node_with_text(img_tag, replacement)\n        cnt += 1\n\n\ndef process_html(\n    html: str,\n    url: str,\n    title: str | None,\n    session: aiohttp.ClientSession | None = None,\n    display_urls: bool = False,\n) -> PageContents:\n    \"\"\"Convert HTML into model-readable version.\"\"\"\n    html = remove_unicode_smp(html)\n    html = _replace_special_chars(html)\n    root = lxml.html.fromstring(html)\n\n    # Parse the title.\n    title_element = root.find(\".//title\")\n    if title:\n        final_title = title\n    elif title_element is not None:\n        final_title = title_element.text or \"\"\n    elif url and (domain := get_domain(url)):\n        final_title = domain\n    else:\n        final_title = \"\"\n\n    urls = _clean_links(root, url)\n    replace_images(\n        root=root,\n        base_url=url,\n        session=session,\n    )\n    _remove_math(root)\n    clean_html = lxml.etree.tostring(root, encoding=\"UTF-8\").decode()\n    text = html_to_text(clean_html)\n    text = re.sub(WHITESPACE_ANCHOR_RE, lambda m: m.group(2) + m.group(1), text)\n    # ^^^ move anchors to the right thru whitespace\n    # This way anchors don't create extra whitespace\n    text = re.sub(EMPTY_LINE_RE, \"\", text)\n    # ^^^ Get rid of empty lines\n    text = re.sub(EXTRA_NEWLINE_RE, \"\\n\\n\", text)\n    # ^^^ Get rid of extra newlines\n\n    top_parts = []\n    if display_urls:\n        top_parts.append(f\"\\nURL: {url}\\n\")\n    # NOTE: Publication date is currently not extracted due\n    # to 
performance costs.\n\n    return PageContents(\n        url=url,\n        text=\"\".join(top_parts) + text,\n        urls=urls,\n        title=final_title,\n    )\n"
  },
  {
    "path": "gpt_oss/tools/simple_browser/simple_browser_tool.py",
    "content": "import contextvars\nimport dataclasses\nimport functools\nimport itertools\nimport json\nimport re\nimport textwrap\nfrom typing import Any, AsyncIterator, Callable, ParamSpec, Sequence\nfrom urllib.parse import quote, unquote\n\nimport pydantic\nimport structlog\nimport tiktoken\nfrom aiohttp import ClientSession\nfrom openai_harmony import (\n    Author,\n    Content,\n    Message,\n    Role,\n    TextContent,\n    ToolNamespaceConfig\n)\n\nfrom ..tool import Tool\n\n# from functions import Function, from_python\nfrom .backend import (\n    VIEW_SOURCE_PREFIX,\n    Backend,\n    BackendError,\n    maybe_truncate,\n)\nfrom .page_contents import Extract, PageContents\n\nlogger = structlog.stdlib.get_logger(component=__name__)\n\n\n# TODO(zhuohan): Use the correct encoding at release\nENC_NAME = \"o200k_base\"\nFIND_PAGE_LINK_FORMAT = \"# 【{idx}†{title}】\"\nPARTIAL_INITIAL_LINK_PATTERN = re.compile(r\"^[^【】]*】\")\nPARTIAL_FINAL_LINK_PATTERN = re.compile(\n    r\"【\\d*(?:†(?P<content>[^†】]*)(?:†[^†】]*)?)?$\"\n)\nLINK_PATTERN = re.compile(r\"【\\d+†(?P<content>[^†】]+)(?:†[^†】]+)?】\")\n\nCITATION_OUTPUT_PATTERN = re.compile(r\"【(?P<cursor>\\d+)†(?P<content>[^†】]+)(?:†[^†】]+)?】\")\n\nCallParams = ParamSpec(\"CallParams\")\n\n\n_P = ParamSpec(\"_P\")\n_live_function_name = contextvars.ContextVar[str](\"_live_function_name\")\n\n\nclass ToolUsageError(Exception):\n    pass\n\n\ndef function_the_model_can_call(\n    fn: Callable[_P, AsyncIterator[Message]],\n) -> Callable[_P, AsyncIterator[Message]]:\n    fn.__fn_calling_tool_fn_type__ = \"function_the_model_can_call\"  # type: ignore\n\n    @functools.wraps(fn)\n    async def inner(*args: _P.args, **kwargs: _P.kwargs) -> AsyncIterator[Message]:\n        token = _live_function_name.set(fn.__name__)\n        try:\n            async for m in fn(*args, **kwargs):\n                yield m\n        finally:\n            _live_function_name.reset(token)\n\n    return inner\n\n\n@functools.cache\ndef 
_tiktoken_vocabulary_lengths(enc_name: str) -> list[int]:\n    encoding = tiktoken.get_encoding(enc_name)\n    results = []\n    for i in range(encoding.n_vocab):\n        try:\n            results.append(len(encoding.decode([i])))\n        except Exception as e:\n            results.append(1)\n    return results\n\n\n@dataclasses.dataclass(frozen=True)\nclass Tokens:\n    tokens: list[int]\n    tok2idx: list[int]  # Offsets = running sum of lengths.\n\n\n@functools.cache\ndef max_chars_per_token(enc_name: str) -> int:\n    \"\"\"Typical value is 128, but let's be safe.\"\"\"\n    tok_lens = _tiktoken_vocabulary_lengths(enc_name)\n    return max(tok_lens)\n\n\ndef get_tokens(text: str, enc_name: str) -> Tokens:\n    encoding = tiktoken.get_encoding(enc_name)\n    tokens = encoding.encode(text, disallowed_special=())\n    _vocabulary_lengths = _tiktoken_vocabulary_lengths(enc_name)\n    tok2idx = [0] + list(itertools.accumulate(_vocabulary_lengths[i] for i in tokens))[\n        :-1\n    ]\n    result = Tokens(tokens=tokens, tok2idx=tok2idx)\n    return result\n\n\ndef get_end_loc(\n    loc: int,\n    num_lines: int,\n    total_lines: int,\n    lines: list[str],\n    view_tokens: int,\n    encoding_name: str,\n) -> int:\n    if num_lines <= 0:\n        # COMPUTE NUMBER OF LINES TO SHOW\n        txt = join_lines(lines[loc:], add_line_numbers=True, offset=loc)\n        # if the text is very short, no need to truncate at all\n        # at least one char per token\n        if len(txt) > view_tokens:\n            # limit the amount of text we tokenize here\n            upper_bound = max_chars_per_token(encoding_name)\n            tok2idx = get_tokens(\n                txt[: (view_tokens + 1) * upper_bound], encoding_name\n            ).tok2idx\n            if len(tok2idx) > view_tokens:\n                end_idx = tok2idx[view_tokens]\n                num_lines = txt[:end_idx].count(\"\\n\") + 1  # round up\n            else:\n                num_lines = total_lines\n      
  else:\n            num_lines = total_lines\n\n    return min(loc + num_lines, total_lines)\n\n\ndef get_page_metadata(\n    curr_page: PageContents,\n) -> dict[str, str | None | dict[str, str] | list[str]]:\n    \"\"\"Some attributes of the current page.\"\"\"\n    page_metadata: dict[str, str | None | dict[str, str] | list[str]] = {\n        \"url\": curr_page.url,\n        \"title\": curr_page.title,\n    }\n    return page_metadata\n\n\ndef join_lines(\n    lines: list[str], add_line_numbers: bool = False, offset: int = 0\n) -> str:\n    if add_line_numbers:\n        return \"\\n\".join([f\"L{i + offset}: {line}\" for i, line in enumerate(lines)])\n    else:\n        return \"\\n\".join(lines)\n\n\ndef wrap_lines(text: str, width: int = 80) -> list[str]:\n    lines = text.split(\"\\n\")\n    wrapped = itertools.chain.from_iterable(\n        (\n            textwrap.wrap(\n                line, width=width, replace_whitespace=False, drop_whitespace=False\n            )\n            if line\n            else [\"\"]\n        )  # preserve empty lines\n        for line in lines\n    )\n    return list(wrapped)\n\n\ndef strip_links(text: str) -> str:\n    text = re.sub(PARTIAL_INITIAL_LINK_PATTERN, \"\", text)\n    text = re.sub(PARTIAL_FINAL_LINK_PATTERN, lambda mo: mo.group(\"content\"), text)\n    text = re.sub(LINK_PATTERN, lambda mo: mo.group(\"content\"), text)\n    return text\n\n\ndef maybe_get_function_args(\n    message: Message, tool_name: str = \"browser\"\n) -> dict[str, Any] | None:\n    if not message.recipient.startswith(f\"{tool_name}.\"):\n        return None\n\n    contents = \"\"\n    if len(message.content) == 1 and isinstance(message.content[0], TextContent):\n        contents = message.content[0].text\n\n    if not contents:\n        return {}\n\n    try:\n        parsed_contents = json.loads(contents)\n        if isinstance(parsed_contents, dict):\n            return parsed_contents\n    except json.JSONDecodeError:\n        pass\n\n    
return None\n\n\nasync def run_find_in_page(\n    pattern: str,\n    page: PageContents,\n    max_results: int = 50,\n    num_show_lines: int = 4,\n) -> PageContents:\n    lines = wrap_lines(text=page.text)\n    txt = join_lines(lines, add_line_numbers=False)\n    without_links = strip_links(txt)\n    lines = without_links.split(\"\\n\")\n\n    result_chunks, snippets = [], []\n    line_idx, match_idx = 0, 0\n    while line_idx < len(lines):\n        line = lines[line_idx]\n        if pattern not in line.lower():\n            line_idx += 1\n            continue\n        snippet = \"\\n\".join(lines[line_idx : line_idx + num_show_lines])\n        link_title = FIND_PAGE_LINK_FORMAT.format(\n            idx=f\"{match_idx}\", title=f\"match at L{line_idx}\"\n        )\n        result_chunks.append(f\"{link_title}\\n{snippet}\")\n        snippets.append(\n            Extract(\n                url=page.url, text=snippet, title=f\"#{match_idx}\", line_idx=line_idx\n            )\n        )\n        if len(result_chunks) == max_results:\n            break\n        match_idx += 1\n        line_idx += num_show_lines\n\n    urls = [page.url for _ in result_chunks]\n\n    if result_chunks:\n        display_text = \"\\n\\n\".join(result_chunks)\n    else:\n        display_text = f\"No `find` results for pattern: `{pattern}`\"\n\n    result_page = PageContents(\n        url=f\"{page.url}/find?pattern={quote(pattern)}\",\n        title=f\"Find results for text: `{pattern}` in `{page.title}`\",\n        text=display_text,\n        urls={str(i): url for i, url in enumerate(urls)},\n        snippets={str(i): snip for i, snip in enumerate(snippets)},\n    )\n    return result_page\n\n\ndef handle_errors(\n    func: Callable[CallParams, AsyncIterator[\"Message\"]],\n) -> Callable[CallParams, AsyncIterator[\"Message\"]]:\n    @functools.wraps(func)\n    async def inner(\n        *args: CallParams.args, **kwargs: CallParams.kwargs\n    ) -> AsyncIterator[Message]:\n        tool = 
args[0]\n        # Could be cool to type this explicitly, but mypy makes it hard\n        assert isinstance(tool, SimpleBrowserTool)\n        try:\n            async for msg in func(*args, **kwargs):\n                yield msg\n        except (ToolUsageError, BackendError) as e:\n            yield tool.make_error_message(e)\n\n    return inner\n\n\nclass SimpleBrowserState(pydantic.BaseModel):\n    # maps page url to page contents\n    pages: dict[str, PageContents] = pydantic.Field(default_factory=dict)\n    # a sequential list of page urls\n    page_stack: list[str] = pydantic.Field(default_factory=list)\n\n    @property\n    def current_cursor(self) -> int:\n        return len(self.page_stack) - 1\n\n    def add_page(self, page: PageContents) -> None:\n        self.pages[page.url] = page\n        self.page_stack.append(page.url)\n\n    def get_page(self, cursor: int = -1) -> PageContents:\n        if self.current_cursor < 0:\n            raise ToolUsageError(\"No pages to access!\")\n        if cursor == -1 or cursor == self.current_cursor:\n            return self.pages[self.page_stack[-1]]\n        try:\n            page_url = self.page_stack[cursor]\n        except TypeError as e:\n            raise ToolUsageError(\n                f\"`cursor` should be an integer, not `{type(cursor).__name__}`\"\n            ) from e\n        except IndexError as e:\n            raise ToolUsageError(\n                f\"Cursor `{cursor}` is out of range. 
\"\n                f\"Available cursor indices: [0 - {self.current_cursor}].\"\n            ) from e\n        return self.pages[page_url]\n\n    def get_page_by_url(self, url: str) -> PageContents | None:\n        if url in self.pages:\n            return self.pages[url]\n        return None\n\n    def pop_page_stack(self) -> None:\n        assert self.current_cursor >= 0, \"No page to pop!\"\n        self.page_stack.pop()\n\n\nclass SimpleBrowserTool(Tool):\n    def __init__(\n        self,\n        backend: Backend,\n        encoding_name: str = ENC_NAME,\n        max_search_results: int = 20,\n        tool_state: dict[str, Any] | None = None,\n        view_tokens: int = 1024,\n        name: str = \"browser\",\n    ):\n        assert name == \"browser\"\n        self.backend = backend\n        if tool_state is None:\n            self.tool_state = SimpleBrowserState()\n        else:\n            self.tool_state = SimpleBrowserState.model_validate(tool_state)\n\n        self.encoding_name = encoding_name\n        self.max_search_results = max_search_results\n        self.view_tokens = view_tokens\n\n    def get_tool_state(self) -> dict[str, Any]:\n        return {\"tool_state\": self.tool_state.model_dump()}\n\n    @classmethod\n    def get_tool_name(cls) -> str:\n        return \"browser\"\n\n    @property\n    def name(self) -> str:\n        return self.get_tool_name()\n\n    @property\n    def tool_config(self) -> ToolNamespaceConfig:\n        config = ToolNamespaceConfig.browser()\n        config.name = self.name\n        config.description = \"\"\"Tool for browsing.\nThe `cursor` appears in brackets before each browsing display: `[{cursor}]`.\nCite information from the tool using the following format:\n`【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\nDo not quote more than 10 words directly from the tool output.\nsources=\"\"\" + self.backend.source\n        return config\n\n    @property\n    def instruction(self) -> str:\n  
      return self.tool_config.description\n\n    def _render_browsing_display(\n        self,\n        tether_id: int,\n        result: str,\n        summary: str | None = None,\n    ):\n        to_return = \"\"\n        # Always show summaries.\n        if summary:\n            to_return += summary\n        to_return += result\n        to_return = f\"[{tether_id}] {to_return}\"\n        return to_return\n\n    def _make_response(\n        self,\n        page: PageContents,\n        cursor: int,\n        body: str,\n        scrollbar: str,\n    ) -> Message:\n        domain = maybe_truncate(unquote(page.url))\n        header = f\"{page.title}\"\n        if domain:\n            header += f\" ({domain})\"\n        header += f\"\\n**{scrollbar}**\\n\\n\"\n\n        content = TextContent(text=self._render_browsing_display(cursor, body, header))\n        return self.make_response(\n            content=content, metadata=get_page_metadata(self.tool_state.get_page())\n        )\n\n    async def show_page(self, loc: int = 0, num_lines: int = -1) -> Message:\n        page = self.tool_state.get_page()\n        cursor = self.tool_state.current_cursor\n        lines = wrap_lines(text=page.text)\n        total_lines = len(lines)\n\n        if loc >= total_lines:\n            err_msg = (\n                f\"Invalid location parameter: `{loc}`. 
\"\n                f\"Cannot exceed page maximum of {total_lines - 1}.\"\n            )\n            raise ToolUsageError(err_msg)\n\n        end_loc = get_end_loc(\n            loc, num_lines, total_lines, lines, self.view_tokens, self.encoding_name\n        )\n\n        lines_to_show = lines[loc:end_loc]\n        body = join_lines(lines_to_show, add_line_numbers=True, offset=loc)\n\n        scrollbar = f\"viewing lines [{loc} - {end_loc - 1}] of {total_lines - 1}\"\n        return self._make_response(page, cursor, body, scrollbar)\n\n    async def show_page_safely(self, loc: int = 0, num_lines: int = -1) -> Message:\n        try:\n            return await self.show_page(loc=loc, num_lines=num_lines)\n        except ToolUsageError as e:\n            self.tool_state.pop_page_stack()\n            raise e\n\n    async def _open_url(self, url: str, direct_url_open: bool) -> PageContents:\n        \"\"\"Use the cache, if available.\"\"\"\n        backend = self.backend\n        # direct_url_open should be regarded as a refresh\n        if not direct_url_open and (page := self.tool_state.get_page_by_url(url)):\n            assert page.url == url\n            return page\n\n        try:\n            async with ClientSession() as session:\n                page = await backend.fetch(url, session=session)\n            return page\n        except Exception as e:\n            msg = maybe_truncate(str(e))\n            logger.warning(\"Error fetching URL in lean browser tool\", exc_info=e)\n            raise BackendError(\n                f\"Error fetching URL `{maybe_truncate(url)}`: {msg}\"\n            ) from e\n\n    def make_error_message(self, error: Exception) -> Message:\n        \"\"\"Uses the message creation codepath from the base class.\"\"\"\n        error_name = error.__class__.__name__\n        content = TextContent(text=str(error))\n        return self.make_response(content=content)\n\n    @function_the_model_can_call\n    @handle_errors\n    async def 
search(\n        self,\n        query: str,\n        topn: int = 10,\n        top_n: int = 10,\n        source: str | None = None,\n    ) -> AsyncIterator[Message]:\n        del topn\n        del top_n\n        try:\n            async with ClientSession() as session:\n                search_page = await self.backend.search(\n                    query=query,\n                    topn=self.max_search_results,\n                    session=session,\n                )\n        except Exception as e:\n            msg = maybe_truncate(str(e))\n            raise BackendError(f\"Error during search for `{query}`: {msg}\") from e\n\n        self.tool_state.add_page(search_page)\n        yield await self.show_page_safely(loc=0)\n\n    @function_the_model_can_call\n    @handle_errors\n    async def open(\n        self,\n        id: int | str = -1,\n        cursor: int = -1,\n        loc: int = -1,\n        num_lines: int = -1,\n        view_source: bool = False,\n        source: str | None = None,\n    ) -> AsyncIterator[Message]:\n        curr_page: PageContents | None = None\n        stay_on_current_page = False\n        direct_url_open = False\n        if isinstance(id, str):\n            snippet = None\n            url = id\n            direct_url_open = True\n        else:  # Operate on a previously opened page\n            curr_page = self.tool_state.get_page(cursor)\n\n            if id >= 0:  # click a link\n                try:\n                    url = curr_page.urls[str(id)]\n                except KeyError as e:\n                    raise ToolUsageError(f\"Invalid link id `{id}`.\") from e\n                snippet = (curr_page.snippets or {}).get(str(id))\n                if snippet and curr_page.url == \"\":\n                    # current page is a search result page\n                    assert isinstance(snippet, Extract)\n            else:  # navigate to new position on the current page\n                if not view_source:\n                    
stay_on_current_page = True\n                url = curr_page.url\n                snippet = None\n\n        new_page: PageContents\n        if view_source:\n            url = f\"{VIEW_SOURCE_PREFIX}{url}\"\n            snippet = None\n        if stay_on_current_page:\n            assert curr_page is not None\n            new_page = curr_page\n        else:\n            new_page = await self._open_url(url, direct_url_open)\n\n        self.tool_state.add_page(new_page)\n\n        if loc < 0:  # unset\n            if snippet is not None and snippet.line_idx is not None:\n                loc = snippet.line_idx\n                if loc > 4:\n                    loc -= 4\n            else:\n                loc = 0\n        yield await self.show_page_safely(loc=loc, num_lines=num_lines)\n\n    @function_the_model_can_call\n    @handle_errors\n    async def find(self, pattern: str, cursor: int = -1) -> AsyncIterator[Message]:\n        page = self.tool_state.get_page(cursor)\n        if page.snippets is not None:\n            raise ToolUsageError(\n                \"Cannot run `find` on search results page or find results page\"\n            )\n\n        pc = await run_find_in_page(\n            pattern=str(pattern).lower(),\n            page=page,\n        )\n        self.tool_state.add_page(pc)\n        yield await self.show_page_safely(loc=0)\n\n    def make_response(\n        self,\n        content: Content,\n        *,\n        metadata: dict[str, Any] | None = None,\n        author: Author | None = None,\n    ) -> Message:\n        \"\"\"\n        Make a response message.\n\n        Should be used from `@function_the_model_can_call` if author is not provided.\n        \"\"\"\n        if author is None:\n            tool_name = self.get_tool_name()\n            function_name = _live_function_name.get()\n            assert function_name is not None\n            author = Author(role=Role.TOOL, name=f\"{tool_name}.{function_name}\")\n\n        return Message(\n            
author=author,\n            content=[content],\n        ).with_recipient(\"assistant\")\n\n    def process_arguments(self, message: Message) -> dict[str, Any]:\n        function_args = maybe_get_function_args(message, tool_name=self.name)\n        if function_args is None:\n            raise ValueError(\"Invalid function arguments\")\n\n        if \"cursor\" in function_args and function_args[\"cursor\"] >= 0:\n            page = self.tool_state.get_page(cursor=function_args[\"cursor\"])\n            if \"id\" in function_args:\n                function_args[\"url\"] = page.urls[str(function_args[\"id\"])]\n            else:\n                function_args[\"url\"] = page.url\n        elif \"id\" in function_args and isinstance(function_args[\"id\"], str):\n            function_args[\"url\"] = function_args[\"id\"]\n        return function_args\n\n    async def _process(self, message: Message) -> AsyncIterator[Message]:\n        def make_error_message(error: str) -> Message:\n            return self.make_response(\n                content=TextContent(text=json.dumps({\"error\": error})),\n                author=Author(role=Role.TOOL, name=message.recipient),\n            )\n\n        function_args = maybe_get_function_args(message, tool_name=self.name)\n        if function_args is None:\n            yield make_error_message(\"Invalid function arguments\")\n            return\n\n        _, function_name = message.recipient.split(\".\")\n        if function_name not in [\"search\", \"open\", \"find\"]:\n            yield make_error_message(f\"Unknown function: {function_name}\")\n            return\n\n        if function_name == \"search\":\n            async for msg in self.search(**function_args):\n                yield msg\n        elif function_name == \"open\":\n            async for msg in self.open(**function_args):\n                yield msg\n        elif function_name == \"find\":\n            async for msg in self.find(**function_args):\n                
yield msg\n        else:\n            raise ValueError(\"should not be here\")\n\n\n    def normalize_citations(self, old_content: str, hide_partial_citations: bool = False) -> tuple[str, list[dict[str, Any]], bool]:\n        \"\"\"\n        Returns a tuple of (new_content, annotations, has_partial_citations)\n        - new_content: text with citations replaced by ([domain](url))\n        - annotations: list of dicts with start_index, end_index, title (the domain), url, and type\n        - has_partial_citations: whether the text includes an unfinished citation\n        \"\"\"\n\n        has_partial_citations = PARTIAL_FINAL_LINK_PATTERN.search(old_content) is not None\n        if hide_partial_citations and has_partial_citations:\n            old_content = PARTIAL_FINAL_LINK_PATTERN.sub(\"\", old_content)\n\n        matches = []\n        for match in CITATION_OUTPUT_PATTERN.finditer(old_content):\n            cursor = match.group(\"cursor\")\n            content = match.group(\"content\")\n            start_idx = match.start()\n            end_idx = match.end()\n            matches.append({\n                \"cursor\": cursor,\n                \"content\": content,\n                \"start\": start_idx,\n                \"end\": end_idx\n            })\n\n        # Build a mapping from cursor to url\n        cursor_to_url = {}\n        for idx, url in enumerate(self.tool_state.page_stack):\n            cursor_to_url[str(idx)] = url\n\n        def extract_domain(url):\n            try:\n                return unquote(url).split(\"/\")[2]\n            except Exception:\n                return url\n\n        new_content = \"\"\n        last_idx = 0\n        annotations = []\n\n        for m in matches:\n            cursor = m[\"cursor\"]\n            url = cursor_to_url.get(cursor, None)\n            orig_start = m[\"start\"]\n            orig_end = m[\"end\"]\n\n            # Add text before the citation\n            new_content += old_content[last_idx:orig_start]\n\n            if url:\n                domain = extract_domain(url)\n                replacement = f\" ([{domain}]({url})) \"\n                # The start and end indices in the new content\n                start_index = len(new_content)\n                end_index = start_index + len(replacement)\n                annotations.append({\n                    \"start_index\": start_index,\n                    \"end_index\": end_index,\n                    \"title\": domain,\n                    \"url\": url,\n                    \"type\": \"url_citation\",\n                })\n                new_content += replacement\n            else:\n                # Keep the original citation text if the cursor is unknown; no annotation is added\n                new_content += old_content[orig_start:orig_end]\n\n            last_idx = orig_end\n\n        new_content += old_content[last_idx:]\n        return new_content, annotations, has_partial_citations\n\n"
  },
  {
    "path": "gpt_oss/tools/tool.py",
    "content": "from abc import ABC, abstractmethod\nfrom uuid import UUID, uuid4\nfrom typing import AsyncIterator\n\nfrom openai_harmony import (\n    Author,\n    Role,\n    Message,\n    TextContent,\n)\n\n\ndef _maybe_update_inplace_and_validate_channel(\n    *, input_message: Message, tool_message: Message\n) -> None:\n    # If the channel of a new message produced by tool is different from the originating message,\n    # we auto-set the new message's channel, if unset, or raise an error.\n    if tool_message.channel != input_message.channel:\n        if tool_message.channel is None:\n            tool_message.channel = input_message.channel\n        else:\n            raise ValueError(\n                f\"Messages from tool should have the same channel ({tool_message.channel=}) as \"\n                f\"the triggering message ({input_message.channel=}).\"\n            )\n\n\nclass Tool(ABC):\n    \"\"\"\n    Something the model can call.\n\n    Tools expose APIs that are shown to the model in a syntax that the model\n    understands and knows how to call (from training data). Tools allow the\n    model to do things like run code, browse the web, etc.\n    \"\"\"\n\n    @property\n    @abstractmethod\n    def name(self) -> str:\n        \"\"\"\n        An identifier for the tool. 
The convention is that a message will be routed to the tool\n        whose name matches its recipient field.\n        \"\"\"\n\n    @property\n    def output_channel_should_match_input_channel(self) -> bool:\n        \"\"\"\n        A flag which indicates whether the output channel of the tool should match the input channel.\n        \"\"\"\n        return True\n\n    async def process(self, message: Message) -> AsyncIterator[Message]:\n        \"\"\"\n        Process the message and return a list of messages to add to the conversation.\n        The input message should already be applicable to this tool.\n        Don't return the input message, just the new messages.\n\n        If implementing a tool that has to block while calling a function use `call_on_background_thread` to get a coroutine.\n\n        If you just want to test this use `evaluate_generator` to get the results.\n\n        Do not override this method; override `_process` below (to avoid interfering with tracing).\n        \"\"\"\n        async for m in self._process(message):\n            if self.output_channel_should_match_input_channel:\n                _maybe_update_inplace_and_validate_channel(input_message=message, tool_message=m)\n            yield m\n\n    @abstractmethod\n    async def _process(self, message: Message) -> AsyncIterator[Message]:\n        \"\"\"Override this method to provide the implementation of the tool.\"\"\"\n        if False:  # This is to convince the type checker that this is an async generator.\n            yield  # type: ignore[unreachable]\n        _ = message  # Stifle \"unused argument\" warning.\n        raise NotImplementedError\n\n    @abstractmethod\n    def instruction(self) -> str:\n        \"\"\"\n        Describe the tool's functionality. 
For example, if it accepts Python code,\n        provide documentation on the functions available.\n        \"\"\"\n        raise NotImplementedError\n\n    def instruction_dict(self) -> dict[str, str]:\n        return {self.name: self.instruction()}\n\n    def error_message(\n        self, error_message: str, id: UUID | None = None, channel: str | None = None\n    ) -> Message:\n        \"\"\"\n        Return an error message authored by this tool and addressed to the assistant.\n        \"\"\"\n        return Message(\n            id=id if id else uuid4(),\n            author=Author(role=Role.TOOL, name=self.name),\n            content=TextContent(text=error_message),  # TODO: Use SystemError instead\n            channel=channel,\n        ).with_recipient(\"assistant\")\n\n"
  },
  {
    "path": "gpt_oss/torch/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/torch/model.py",
    "content": "import json\nimport math\nimport os\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.distributed as dist\n\nfrom gpt_oss.torch.weights import Checkpoint\n\n\n@dataclass\nclass ModelConfig:\n    num_hidden_layers: int = 36\n    num_experts: int = 128\n    experts_per_token: int = 4\n    vocab_size: int = 201088\n    hidden_size: int = 2880\n    intermediate_size: int = 2880\n    swiglu_limit: float = 7.0\n    head_dim: int = 64\n    num_attention_heads: int = 64\n    num_key_value_heads: int = 8\n    sliding_window: int = 128\n    initial_context_length: int = 4096\n    rope_theta: float = 150000.0\n    rope_scaling_factor: float = 32.0\n    rope_ntk_alpha: float = 1.0\n    rope_ntk_beta: float = 32.0\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(\n        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None\n    ):\n        super().__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.scale = torch.nn.Parameter(\n            torch.ones(num_features, device=device, dtype=torch.float32)\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        assert x.shape[-1] == self.num_features\n        t, dtype = x.float(), x.dtype\n        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)\n        return (t * self.scale).to(dtype)\n\n\ndef _apply_rotary_emb(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n) -> torch.Tensor:\n    cos = cos.unsqueeze(-2).to(x.dtype)\n    sin = sin.unsqueeze(-2).to(x.dtype)\n    x1, x2 = torch.chunk(x, 2, dim=-1)\n    o1 = x1 * cos - x2 * sin\n    o2 = x2 * cos + x1 * sin\n    return torch.cat((o1, o2), dim=-1)\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(\n        self,\n        head_dim: int,\n        base: int,\n        dtype: torch.dtype,\n        initial_context_length: int = 4096,\n        scaling_factor: float = 1.0,\n        ntk_alpha: float = 1.0,\n     
   ntk_beta: float = 32.0,\n        device: torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.head_dim = head_dim\n        self.base = base\n        self.dtype = dtype\n        self.initial_context_length = initial_context_length\n        self.scaling_factor = scaling_factor\n        self.ntk_alpha = ntk_alpha\n        self.ntk_beta = ntk_beta\n        self.device = device\n\n    def _compute_concentration_and_inv_freq(self) -> torch.Tensor:\n        \"\"\"See YaRN paper: https://arxiv.org/abs/2309.00071\"\"\"\n        freq = self.base ** (\n            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)\n            / self.head_dim\n        )\n        if self.scaling_factor > 1.0:\n            concentration = (\n                0.1 * math.log(self.scaling_factor) + 1.0\n            )  # YaRN concentration\n\n            d_half = self.head_dim / 2\n            # NTK by parts\n            low = (\n                d_half\n                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))\n                / math.log(self.base)\n            )\n            high = (\n                d_half\n                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))\n                / math.log(self.base)\n            )\n            assert 0 < low < high < d_half - 1\n\n            interpolation = 1.0 / (self.scaling_factor * freq)\n            extrapolation = 1.0 / freq\n\n            ramp = (\n                torch.arange(d_half, dtype=torch.float32, device=freq.device) - low\n            ) / (high - low)\n            mask = 1 - ramp.clamp(0, 1)\n\n            inv_freq = interpolation * (1 - mask) + extrapolation * mask\n        else:\n            concentration = 1.0\n            inv_freq = 1.0 / freq\n\n        return concentration, inv_freq\n\n    def _compute_cos_sin(self, num_tokens: int):\n        concentration, inv_freq = self._compute_concentration_and_inv_freq()\n        t 
= torch.arange(num_tokens, dtype=torch.float32, device=self.device)\n        freqs = torch.einsum(\"i,j->ij\", t, inv_freq)\n        cos = freqs.cos() * concentration\n        sin = freqs.sin() * concentration\n        return cos, sin\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        num_tokens = query.shape[0]\n        cos, sin = self._compute_cos_sin(num_tokens)\n\n        query_shape = query.shape\n        query = query.view(num_tokens, -1, self.head_dim)\n        query = _apply_rotary_emb(query, cos, sin)\n        query = query.reshape(query_shape)\n\n        key_shape = key.shape\n        key = key.view(num_tokens, -1, self.head_dim)\n        key = _apply_rotary_emb(key, cos, sin)\n        key = key.reshape(key_shape)\n        return query, key\n\n\ndef sdpa(Q, K, V, S, sm_scale, sliding_window=0):\n    # sliding_window == 0 means no sliding window\n    n_tokens, n_heads, q_mult, d_head = Q.shape\n    assert K.shape == (n_tokens, n_heads, d_head)\n    assert V.shape == (n_tokens, n_heads, d_head)\n    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)\n    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)\n    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)\n    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float(\"inf\")), diagonal=1)\n    if sliding_window > 0:\n        mask += torch.tril(\n            mask.new_full((n_tokens, n_tokens), -float(\"inf\")), diagonal=-sliding_window\n        )\n    QK = torch.einsum(\"qhmd,khmd->hmqk\", Q, K)\n    QK *= sm_scale\n    QK += mask[None, None, :, :]\n    QK = torch.cat([QK, S], dim=-1)\n    W = torch.softmax(QK, dim=-1)\n    W = W[..., :-1]\n    attn = torch.einsum(\"hmqk,khmd->qhmd\", W, V)\n    return attn.reshape(n_tokens, -1)\n\n\nclass AttentionBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        layer_idx: int = 0,\n        device: torch.device | None 
= None,\n    ):\n        super().__init__()\n        self.head_dim = config.head_dim\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        # Only apply sliding window to every other layer\n        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0\n        self.sinks = torch.nn.Parameter(\n            torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)\n        )\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        qkv_dim = config.head_dim * (\n            config.num_attention_heads + 2 * config.num_key_value_heads\n        )\n        self.qkv = torch.nn.Linear(\n            config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16\n        )\n        self.out = torch.nn.Linear(\n            config.head_dim * config.num_attention_heads,\n            config.hidden_size,\n            device=device,\n            dtype=torch.bfloat16,\n        )\n        self.sm_scale = 1 / math.sqrt(config.head_dim)\n        self.rope = RotaryEmbedding(\n            config.head_dim,\n            config.rope_theta,\n            torch.float32,\n            initial_context_length=config.initial_context_length,\n            scaling_factor=config.rope_scaling_factor,\n            ntk_alpha=config.rope_ntk_alpha,\n            ntk_beta=config.rope_ntk_beta,\n            device=device,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        t = self.norm(x)\n        qkv = self.qkv(t)\n        q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()\n        k = qkv[\n            :,\n            self.num_attention_heads\n            * self.head_dim : (self.num_attention_heads + self.num_key_value_heads)\n            * self.head_dim,\n        ].contiguous()\n        v = qkv[\n            :,\n            (self.num_attention_heads + self.num_key_value_heads)\n            * self.head_dim : 
(self.num_attention_heads + 2 * self.num_key_value_heads)\n            * self.head_dim,\n        ].contiguous()\n\n        q = q.view(\n            -1,\n            self.num_key_value_heads,\n            self.num_attention_heads // self.num_key_value_heads,\n            self.head_dim,\n        )\n        k = k.view(-1, self.num_key_value_heads, self.head_dim)\n        v = v.view(-1, self.num_key_value_heads, self.head_dim)\n        q, k = self.rope(q, k)\n        t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window)\n        t = self.out(t)\n        t = x + t\n        return t\n\n\ndef swiglu(x, alpha: float = 1.702, limit: float = 7.0):\n    x_glu, x_linear = x[..., ::2], x[..., 1::2]\n    # Clamp the input values\n    x_glu = x_glu.clamp(min=None, max=limit)\n    x_linear = x_linear.clamp(min=-limit, max=limit)\n    out_glu = x_glu * torch.sigmoid(alpha * x_glu)\n    # Note we add an extra bias of 1 to the linear layer\n    return out_glu * (x_linear + 1)\n\n\nclass MLPBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.experts_per_token = config.experts_per_token\n        self.swiglu_limit = config.swiglu_limit\n        self.world_size = dist.get_world_size() if dist.is_initialized() else 1\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        self.gate = torch.nn.Linear(\n            config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16\n        )\n        assert config.intermediate_size % self.world_size == 0\n        self.mlp1_weight = torch.nn.Parameter(\n            torch.empty(\n                (\n                    config.num_experts,\n                    config.intermediate_size * 2 // self.world_size,\n                    config.hidden_size,\n                ),\n                device=device,\n                
dtype=torch.bfloat16,\n            )\n        )\n        self.mlp1_bias = torch.nn.Parameter(\n            torch.empty(\n                (config.num_experts, config.intermediate_size * 2 // self.world_size),\n                device=device,\n                dtype=torch.bfloat16,\n            )\n        )\n        self.mlp2_weight = torch.nn.Parameter(\n            torch.empty(\n                (\n                    config.num_experts,\n                    config.hidden_size,\n                    config.intermediate_size // self.world_size,\n                ),\n                device=device,\n                dtype=torch.bfloat16,\n            )\n        )\n        self.mlp2_bias = torch.nn.Parameter(\n            torch.empty(\n                (config.num_experts, config.hidden_size),\n                device=device,\n                dtype=torch.bfloat16,\n            )\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        t = self.norm(x)\n        g = self.gate(t)\n        experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True)\n        expert_weights = torch.nn.functional.softmax(experts.values, dim=1)\n        expert_indices = experts.indices\n\n        # MLP #1\n        mlp1_weight = self.mlp1_weight[expert_indices, ...]\n        mlp1_bias = self.mlp1_bias[expert_indices, ...]\n        t = torch.einsum(\"beck,bk->bec\", mlp1_weight, t) + mlp1_bias\n        t = swiglu(t, limit=self.swiglu_limit)\n\n        # MLP #2\n        mlp2_weight = self.mlp2_weight[expert_indices, ...]\n        mlp2_bias = self.mlp2_bias[expert_indices, ...]\n        t = torch.einsum(\"beck,bek->bec\", mlp2_weight, t)\n        if self.world_size > 1:\n            dist.all_reduce(t, op=dist.ReduceOp.SUM)\n        t += mlp2_bias\n\n        # Weighted sum of experts\n        t = torch.einsum(\"bec,be->bc\", t, expert_weights)\n\n        return x + t\n\n\nclass TransformerBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: 
ModelConfig,\n        layer_idx: int,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.attn = AttentionBlock(config, layer_idx, device)\n        self.mlp = MLPBlock(config, device)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.attn(x)\n        x = self.mlp(x)\n        return x\n\n\nclass Transformer(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.embedding = torch.nn.Embedding(\n            config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16\n        )\n        self.block = torch.nn.ModuleList(\n            [\n                TransformerBlock(config, layer_idx, device)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        self.unembedding = torch.nn.Linear(\n            config.hidden_size,\n            config.vocab_size,\n            bias=False,\n            device=device,\n            dtype=torch.bfloat16,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.embedding(x)\n        for block in self.block:\n            x = block(x)\n        x = self.norm(x)\n        x = self.unembedding(x)\n        return x\n\n    @staticmethod\n    def from_checkpoint(\n        path: str, device: str | torch.device = \"cuda\"\n    ) -> \"Transformer\":\n        if not isinstance(device, torch.device):\n            device = torch.device(device)\n\n        config_path = os.path.join(path, \"config.json\")\n        with open(config_path, \"r\") as f:\n            json_config = json.load(f)\n            config = ModelConfig(**json_config)\n\n        model = Transformer(\n            config=config,\n            device=device,\n        )\n        model.eval()\n\n        # Load 
weights\n        my_rank = dist.get_rank() if dist.is_initialized() else 0\n        world_size = dist.get_world_size() if dist.is_initialized() else 1\n        per_rank_intermediate_size = config.intermediate_size // world_size\n\n        checkpoint = Checkpoint(path, device)\n\n        for name, param in model.named_parameters():\n            loaded_tensor = checkpoint.get(name)\n\n            # Note: it would be more efficient to do sharding before upcasting from MXFP4,\n            # but for simplicity we do it after.\n            if \"mlp1\" in name:  # both weight and bias\n                loaded_tensor = loaded_tensor[\n                    :,\n                    my_rank * 2\n                    * per_rank_intermediate_size : (my_rank + 1) * 2\n                    * per_rank_intermediate_size,\n                    ...,\n                ]\n            elif \"mlp2_weight\" in name:  # only weight\n                loaded_tensor = loaded_tensor[\n                    ...,\n                    my_rank\n                    * per_rank_intermediate_size : (my_rank + 1)\n                    * per_rank_intermediate_size,\n                ]\n            try:\n                param.data.copy_(loaded_tensor)\n            except:\n                print(f\"{name=} {param.data.shape=} {loaded_tensor.shape=}\")\n                raise\n\n        return model\n\n\nclass TokenGenerator:\n    @torch.inference_mode()\n    def __init__(self, checkpoint: str, device: torch.device):\n        self.device = device\n        self.model = Transformer.from_checkpoint(checkpoint, device=self.device)\n\n    @torch.inference_mode()\n    def generate(self,\n                 prompt_tokens: list[int],\n                 stop_tokens: list[int],\n                 temperature: float = 1.0,\n                 max_tokens: int = 0,\n                 return_logprobs: bool = False):\n        tokens = list(prompt_tokens)\n        num_generated_tokens = 0\n        while max_tokens == 0 or 
num_generated_tokens < max_tokens:\n            logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1]\n            if temperature == 0.0:\n                predicted_token = torch.argmax(logits, dim=-1).item()\n            else:\n                probs = torch.softmax(logits * (1.0 / temperature), dim=-1)\n                predicted_token = torch.multinomial(probs, num_samples=1).item()\n            tokens.append(predicted_token)\n            num_generated_tokens += 1\n\n            if return_logprobs:\n                logprobs = torch.log_softmax(logits, dim=-1)\n                selected_logprobs = logprobs[predicted_token].item()\n                yield predicted_token, selected_logprobs\n            else:\n                yield predicted_token\n\n            if predicted_token in stop_tokens:\n                break\n"
  },
  {
    "path": "gpt_oss/torch/utils.py",
    "content": "import os\nimport torch\nimport torch.distributed as dist\n\n\ndef suppress_output(rank):\n    \"\"\"Suppress printing on all ranks except rank 0. Pass `force=True` to print from any rank (prefixed with its rank).\"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if force:\n            builtin_print(\"rank #%d:\" % rank, *args, **kwargs)\n        elif rank == 0:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef init_distributed() -> torch.device:\n    \"\"\"Set up distributed inference: initialize the NCCL process group when WORLD_SIZE > 1, bind this process to its CUDA device, and return that device.\"\"\"\n    # Initialize distributed inference\n    world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n    rank = int(os.environ.get(\"RANK\", 0))\n    if world_size > 1:\n        dist.init_process_group(\n            backend=\"nccl\", init_method=\"env://\", world_size=world_size, rank=rank\n        )\n    torch.cuda.set_device(rank)\n    device = torch.device(f\"cuda:{rank}\")\n\n    # Warm up NCCL to avoid first-time latency\n    if world_size > 1:\n        x = torch.ones(1, device=device)\n        dist.all_reduce(x)\n        torch.cuda.synchronize(device)\n\n    suppress_output(rank)\n    return device\n"
  },
  {
    "path": "gpt_oss/torch/weights.py",
    "content": "import math\nimport os\n\nimport torch\nfrom safetensors import safe_open\n\n\n# Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes\nBYTES_PER_BLOCK = 16\n\nFP4_VALUES = [\n    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,\n    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,\n]\n\n# Map the names assumed in this implementation to the checkpoint names.\nPARAM_NAME_MAP = {\n    f\"block.{n}.mlp.mlp1_bias\": f\"block.{n}.mlp.mlp1_bias\" for n in range(36)\n} | {\n    f\"block.{n}.mlp.mlp1_weight\": (f\"block.{n}.mlp.mlp1_weight.blocks\", f\"block.{n}.mlp.mlp1_weight.scales\") for n in range(36)\n} | {\n    f\"block.{n}.mlp.mlp2_bias\": f\"block.{n}.mlp.mlp2_bias\" for n in range(36)\n} | {\n    f\"block.{n}.mlp.mlp2_weight\": (f\"block.{n}.mlp.mlp2_weight.blocks\", f\"block.{n}.mlp.mlp2_weight.scales\") for n in range(36)\n}\n\n\nclass Checkpoint:\n    def __init__(self, path: str, device: torch.device):\n        device_str = (\n            device.type\n            if device.index is None\n            else device.type + \":\" + str(device.index)\n        )\n        self.device_str = device_str\n\n        # Read from all files ending with .safetensors in the checkpoint directory\n        safetensor_files = [\n            os.path.join(path, fname)\n            for fname in os.listdir(path)\n            if fname.endswith(\".safetensors\")\n        ]\n        # Build a mapping from tensor name to (file, key)\n        tensor_name_to_file = {}\n        for safetensor_file in safetensor_files:\n            with safe_open(safetensor_file, framework=\"pt\", device=device_str) as f:\n                for key in f.keys():\n                    tensor_name_to_file[key] = safetensor_file\n\n        self.tensor_name_to_file = tensor_name_to_file\n\n    def get(self, name: str) -> torch.Tensor:\n        match PARAM_NAME_MAP.get(name, name):\n            case (blocks_name, scales_name):\n                # MoE weights: are in block-based MXFP4 format\n            
    return self._get_mxfp4_tensor(blocks_name, scales_name, dtype=torch.bfloat16)\n            case tensor_name:\n                # MoE biases and other weights\n                return self._get_tensor(tensor_name)\n\n    def _get_tensor(self, name: str) -> str:\n        assert name in self.tensor_name_to_file, f\"Tensor {name} not found in checkpoint.\"\n        with safe_open(\n            self.tensor_name_to_file[name], framework=\"pt\", device=self.device_str\n        ) as f:\n            return f.get_tensor(name)\n\n    def _get_mxfp4_tensor(\n        self,\n        blocks_name: str,\n        scales_name: str,\n        *,\n        dtype: torch.dtype = torch.bfloat16,\n        rows_per_chunk: int = 16384 * 512,\n    ) -> torch.Tensor:\n        assert blocks_name in self.tensor_name_to_file, (\n            f\"Blocks tensor {blocks_name} not found in checkpoint.\"\n        )\n        assert scales_name in self.tensor_name_to_file, (\n            f\"Scales tensor {scales_name} not found in checkpoint.\"\n        )\n\n        blocks = self._get_tensor(blocks_name)\n        scales = self._get_tensor(scales_name).to(torch.int32) - 127\n\n        assert blocks.shape[:-1] == scales.shape, (\n            f\"{blocks.shape=} does not match {scales.shape=}\"\n        )\n\n        lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)\n\n        *prefix_shape, G, B = blocks.shape\n        rows_total   = math.prod(prefix_shape) * G\n\n        blocks = blocks.reshape(rows_total, B)\n        scales = scales.reshape(rows_total, 1)\n\n        out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)\n\n        for r0 in range(0, rows_total, rows_per_chunk):\n            r1 = min(r0 + rows_per_chunk, rows_total)\n\n            blk = blocks[r0:r1]\n            exp = scales[r0:r1]\n\n            # nibble indices -> int64\n            idx_lo = (blk & 0x0F).to(torch.long)\n            idx_hi = (blk >> 4).to(torch.long)\n\n            sub = out[r0:r1]\n        
    sub[:, 0::2] = lut[idx_lo]\n            sub[:, 1::2] = lut[idx_hi]\n\n            torch.ldexp(sub, exp, out=sub)\n            del idx_lo, idx_hi, blk, exp\n\n        return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)\n\n    def _get_mxfp4_tensor_copy(self, blocks_name: str, scales_name: str, dtype: torch.dtype = torch.bfloat16):\n        \"short version that uses a lot of memory\"\n\n        loaded_blocks = self._get_tensor(blocks_name)\n        # Split it into low and high nibbles, upcast to bytes, and interleave (for swiglu)\n        loaded_blocks_lo = loaded_blocks & 0x0F\n        loaded_blocks_hi = loaded_blocks >> 4\n        loaded_blocks = torch.stack((loaded_blocks_lo, loaded_blocks_hi), dim=-1)\n        loaded_blocks = loaded_blocks.view(*loaded_blocks.shape[:-2], loaded_blocks.shape[-2] * 2)\n\n        loaded_scales = self._get_tensor(scales_name)\n        # Upcast to int32 and subtract bias\n        loaded_scales = loaded_scales.int() - 127\n\n        # Convert MXFP4 numbers into target dtype\n        fp4_values = torch.tensor(FP4_VALUES, dtype=dtype, device=self.device_str)\n        loaded_tensor = torch.ldexp(fp4_values[loaded_blocks.int()], loaded_scales.unsqueeze(-1))\n        loaded_tensor = loaded_tensor.view(*loaded_tensor.shape[:-2], -1)\n        return loaded_tensor\n"
  },
  {
    "path": "gpt_oss/triton/__init__.py",
    "content": ""
  },
  {
    "path": "gpt_oss/triton/attention.py",
    "content": "\"\"\"FlashAttention w/support for learned sinks and banded attention.\n\nThis is an expanded version of the Flash Attention v2 implementation (see https://tridao.me/publications/flash2/flash2.pdf)\nwhich can be found at https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html.\n\nThis version has been extended to support banded attention and learned attention sinks.\n\"\"\"\n\nimport pytest\nimport torch\n\nimport triton\nimport triton.language as tl\nfrom triton.tools.tensor_descriptor import TensorDescriptor\n\n\n\n@triton.jit\ndef _attn_fwd(\n    Q,\n    K,\n    V,\n    Sinks,\n    sm_scale,\n    M,\n    Out,  #\n    Start_q,\n    Z,\n    H,\n    N_Q_CTX,\n    N_KV_CTX,\n    HEAD_DIM: tl.constexpr,  #\n    BLOCK_M: tl.constexpr,  #\n    BLOCK_N: tl.constexpr,  #\n    BANDWIDTH: tl.constexpr,\n):\n    tl.static_assert(BLOCK_N <= HEAD_DIM)\n    start_q = tl.load(Start_q).to(tl.int32)\n    start_m = tl.program_id(0)\n    off_hz = tl.program_id(1)\n    off_z = off_hz // H\n    off_h = off_hz % H\n\n    # load attention sinks\n    if Sinks is not None:\n        sink = tl.load(Sinks + off_h).to(tl.float32)\n    else:\n        sink = 0\n\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)\n    # load scales\n    qk_scale = sm_scale\n    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])\n\n    if BANDWIDTH:\n        lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M\n    else:\n        lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M\n\n    for start_n in range(lo, hi, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n\n        mask = (start_n + offs_n)[None, :] > (start_q 
+ offs_m)[:, None]\n\n        if BANDWIDTH:\n            too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)\n            mask = mask | too_old\n\n        k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T\n        qk = tl.dot(q, k, allow_tf32=False)\n\n        qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)\n        m_ij = tl.maximum(m_i, tl.max(qk, 1))\n        qk -= m_ij[:, None]\n\n        p = tl.math.exp(qk)\n        alpha = tl.math.exp(m_i - m_ij)\n        l_ij = tl.sum(p, 1)\n        acc = acc * alpha[:, None]\n\n        v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])\n        v = v.to(tl.float32)\n        acc = tl.dot(p, v, acc, allow_tf32=False)\n\n        l_i = l_i * alpha + l_ij\n        m_i = m_ij\n\n    sink = tl.math.exp(sink - m_i)\n    z = l_i + sink\n    acc = acc / z[:, None]\n    m_i += tl.math.log(l_i)\n    m_ptrs = M + off_hz * N_Q_CTX + offs_m\n    tl.store(m_ptrs, m_i)\n    acc = acc.to(Out.dtype)[None, None, :, :]\n    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)\n\n\nclass _attention(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):\n        assert len(start_q) == 1\n        bs, n_ctx, n_kv_heads, repeat_kv, HEAD_DIM_Q = q.shape\n        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_K = k.shape\n        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_V = v.shape\n        n_heads = n_kv_heads * repeat_kv\n        q = q.view(bs, n_ctx, n_heads, HEAD_DIM_Q)\n        k = k.view(bs, n_kv_ctx, n_kv_heads, HEAD_DIM_K)\n        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n        assert HEAD_DIM_K in {16, 32, 64, 128, 256}\n\n        q = q.transpose(1, 2).contiguous()\n        k = k.repeat_interleave(repeat_kv, dim=2).transpose(1, 2).contiguous()\n        v = v.repeat_interleave(repeat_kv, dim=2).transpose(1, 2).contiguous()\n\n        BLOCK_M = 64\n        BLOCK_N = 64\n        m_pad_size = BLOCK_M - n_ctx % 
BLOCK_M if n_ctx % BLOCK_M != 0 else 0\n        # pad q to multiple of its block size in the n_ctx dimension (-2)\n        q = torch.nn.functional.pad(q, (0, 0, 0, m_pad_size))\n        n_pad_size = BLOCK_N - n_kv_ctx % BLOCK_N if n_kv_ctx % BLOCK_N != 0 else 0\n        # pad k and v to multiple of their block size in the n_kv_ctx dimension\n        k = torch.nn.functional.pad(k, (0, 0, 0, n_pad_size))\n        v = torch.nn.functional.pad(v, (0, 0, 0, n_pad_size))\n\n        o = torch.empty_like(q)\n        M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32)\n        grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1)\n        _attn_fwd[grid](\n            TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K]),\n            TensorDescriptor.from_tensor(k, [1, 1, BLOCK_N, HEAD_DIM_K]),\n            TensorDescriptor.from_tensor(v, [1, 1, BLOCK_N, HEAD_DIM_K]),\n            sinks,\n            sm_scale,\n            M,\n            TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, HEAD_DIM_K]),\n            start_q,\n            q.shape[0],\n            q.shape[1],\n            N_Q_CTX=n_ctx + m_pad_size,\n            N_KV_CTX=n_kv_ctx,\n            HEAD_DIM=HEAD_DIM_K,\n            BANDWIDTH=bandwidth,\n            BLOCK_M=BLOCK_M,\n            BLOCK_N=BLOCK_N,\n        )\n\n        ctx.save_for_backward(q, k, v, sinks, o, M, start_q)\n        ctx.sm_scale = sm_scale\n        ctx.bandwidth = bandwidth\n\n        o = o[:, :, :n_ctx, :].transpose(1, 2).contiguous()\n        o = o.view(bs, n_ctx, n_heads * HEAD_DIM_V)\n        return o\n\n\nattention = _attention.apply\n\n\ndef attention_ref(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    sinks: torch.Tensor,\n    sm_scale: float = 0.125,\n    sliding_window: int | None = None,\n    start_q: torch.LongTensor = 0,\n):\n    batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape\n    batch_size, num_keys, 
num_key_value_heads, head_dim = key.shape\n\n    sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()\n    key = key.unsqueeze(3)\n    value = value.unsqueeze(3)\n\n    pos_keys = torch.arange(num_keys, device=query.device)\n    pos_queries = torch.arange(num_queries, device=query.device) + start_q\n    mask = pos_keys[None, :] > pos_queries[:, None]\n    mask = mask.float().masked_fill(mask, float(\"-inf\"))\n\n    if sliding_window:\n        too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)\n        mask.masked_fill_(too_old, float(\"-inf\"))\n\n    logits = torch.einsum(\"bqhmd,bkhmd->bhmqk\", query.float(), key.float()) * sm_scale\n    logits = logits + mask[None, None, None, :, :]\n\n    logits_max = torch.max(logits, dim=-1, keepdim=True).values\n    logits_or_sinks_max = torch.maximum(sinks, logits_max)\n    sinks = torch.exp(sinks - logits_or_sinks_max)\n    unnormalized_scores = torch.exp(logits - logits_or_sinks_max)\n    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks\n    scores = unnormalized_scores / normalizer\n\n    output = torch.einsum(\"bhmqk,bkhmd->bqhmd\", scores, value.float())\n\n    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim).bfloat16()\n    return output\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"num_queries\", [1, 128])\n@pytest.mark.parametrize(\"num_keys\", [128, 32])\n@pytest.mark.parametrize(\"num_key_value_heads\", [8])\n@pytest.mark.parametrize(\"num_key_value_groups\", [8])\n@pytest.mark.parametrize(\"head_dim\", [64])\n@pytest.mark.parametrize(\"sm_scale\", [0.125])\n@pytest.mark.parametrize(\"sliding_window\", [None, 128])\n@pytest.mark.parametrize(\"start_q\", [0, 5])\ndef test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):\n    if num_queries > num_keys:\n        pytest.skip(\"too many 
queries\")\n\n    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()\n    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()\n    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()\n    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()\n\n    start_q = torch.tensor([start_q], dtype=torch.int32).cuda()\n\n    o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)\n    o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)\n\n    torch.testing.assert_close(o1, o2)\n"
  },
  {
    "path": "gpt_oss/triton/model.py",
    "content": "import json\nimport math\nimport os\n\nimport torch\nfrom torch.profiler import record_function\n\nfrom gpt_oss.torch.model import ModelConfig, RMSNorm\nfrom gpt_oss.torch.weights import Checkpoint\nfrom gpt_oss.triton.attention import attention, attention_ref\nfrom gpt_oss.triton.moe import quantize_mx4, moe\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(\n        self,\n        head_dim: int,\n        base: int,\n        dtype: torch.dtype,\n        initial_context_length: int = 4096,\n        max_context_length: int = 131072,\n        scaling_factor: float = 1.0,\n        ntk_alpha: float = 1.0,\n        ntk_beta: float = 32.0,\n        device: torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.head_dim = head_dim\n        self.base = base\n        self.dtype = dtype\n        self.initial_context_length = initial_context_length\n        self.max_context_length = max_context_length\n        self.scaling_factor = scaling_factor\n        self.ntk_alpha = ntk_alpha\n        self.ntk_beta = ntk_beta\n        self.device = device\n        self.cos, self.sin = self._compute_cos_sin(0, self.max_context_length)\n\n    def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:\n        \"\"\"See YaRN paper: https://arxiv.org/abs/2309.00071\"\"\"\n        freq = self.base ** (\n            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)\n            / self.head_dim\n        )\n        if self.scaling_factor > 1.0:\n            concentration = (\n                0.1 * math.log(self.scaling_factor) + 1.0\n            )  # YaRN concentration\n\n            d_half = self.head_dim / 2\n            # NTK by parts\n            low = (\n                d_half\n                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))\n                / math.log(self.base)\n            )\n            high = (\n                d_half\n                * 
math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))\n                / math.log(self.base)\n            )\n            assert 0 < low < high < d_half - 1\n\n            interpolation = 1.0 / (self.scaling_factor * freq)\n            extrapolation = 1.0 / freq\n\n            ramp = (\n                torch.arange(d_half, dtype=torch.float32, device=freq.device) - low\n            ) / (high - low)\n            mask = 1 - ramp.clamp(0, 1)\n\n            inv_freq = interpolation * (1 - mask) + extrapolation * mask\n        else:\n            concentration = 1.0\n            inv_freq = 1.0 / freq\n\n        return concentration, inv_freq\n\n    def _compute_cos_sin(self, start: int, num_tokens: int):\n        concentration, inv_freq = self._compute_concentration_and_inv_freq()\n        t = torch.arange(start, start + num_tokens, dtype=torch.float32, device=self.device)\n        freqs = torch.einsum(\"i,j->ij\", t, inv_freq)\n        cos = freqs.cos() * concentration\n        sin = freqs.sin() * concentration\n        return cos, sin\n\n    @record_function(\"rotate\")\n    def _rotate(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ) -> torch.Tensor:\n        cos = cos[None, :, None, :].to(x.dtype)\n        sin = sin[None, :, None, :].to(x.dtype)\n        x1, x2 = torch.chunk(x, 2, dim=-1)\n        o1 = x1 * cos - x2 * sin\n        o2 = x2 * cos + x1 * sin\n        return torch.cat((o1, o2), dim=-1)\n\n    @record_function(\"rope\")\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        offset: torch.LongTensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        batch_size, num_tokens, num_heads, head_dim = query.shape\n        batch_size, num_tokens, num_key_value_heads, head_dim = key.shape\n\n        idx = torch.arange(num_tokens, device=query.device, dtype=torch.long) + offset\n        idx = idx % self.max_context_length\n        cos = 
self.cos.index_select(0, idx)\n        sin = self.sin.index_select(0, idx)\n\n        query = self._rotate(query, cos, sin)\n        key = self._rotate(key, cos, sin)\n        return query, key\n\n\nclass Cache:\n    def __init__(self, batch_size, n_ctx, n_kv_heads, d_head=64, device: torch.device | None = None):\n        self.k = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)\n        self.v = torch.zeros((batch_size, n_ctx, n_kv_heads, d_head), dtype=torch.bfloat16, device=device)\n        self.offset = torch.zeros((1,), dtype=torch.long, device=device)\n\n    def reset(self):\n        self.k.zero_()\n        self.v.zero_()\n        self.offset.zero_()\n\n    def repeat_interleave(self, n):\n        \"\"\"Repeat each cache entry n times along the batch dimension.\"\"\"\n        self.k = self.k.repeat_interleave(n, dim=0)\n        self.v = self.v.repeat_interleave(n, dim=0)\n\n    def truncate(self, n_ctx):\n        \"\"\"Truncate the cache to the first n_ctx tokens.\"\"\"\n        batch_size, _, n_kv_heads, d_head = self.k.shape\n        assert batch_size == self.v.shape[0]\n        assert n_ctx <= self.k.shape[1]\n        self.k[:, n_ctx:, :, :].zero_()\n        self.v[:, n_ctx:, :, :].zero_()\n        self.offset.fill_(n_ctx)\n        return self.k, self.v\n\n    def extend(self, k, v):\n        batch_size, n_ctx, *_rest = k.shape\n        assert batch_size == self.k.shape[0]\n        indices = torch.arange(0, n_ctx, device=k.device, dtype=torch.long) + self.offset\n        self.k.index_copy_(1, indices, k)\n        self.v.index_copy_(1, indices, v)\n        self.offset.add_(n_ctx)\n        return self.k, self.v\n\n\nclass AttentionBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        layer_idx: int = 0,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.head_dim = config.head_dim\n        self.num_attention_heads = 
config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        # Only apply sliding window to every other layer\n        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0\n        self.layer_idx = layer_idx\n        self.sinks = torch.nn.Parameter(\n            torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)\n        )\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        qkv_dim = config.head_dim * (\n            config.num_attention_heads + 2 * config.num_key_value_heads\n        )\n        self.qkv = torch.nn.Linear(\n            config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16\n        )\n        self.out = torch.nn.Linear(\n            config.head_dim * config.num_attention_heads,\n            config.hidden_size,\n            device=device,\n            dtype=torch.bfloat16,\n        )\n        self.sm_scale = 1 / math.sqrt(config.head_dim)\n        self.rope = RotaryEmbedding(\n            config.head_dim,\n            config.rope_theta,\n            torch.float32,\n            initial_context_length=config.initial_context_length,\n            scaling_factor=config.rope_scaling_factor,\n            ntk_alpha=config.rope_ntk_alpha,\n            ntk_beta=config.rope_ntk_beta,\n            device=device,\n        )\n\n    @record_function(\"attn\")\n    def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:\n        batch_size, n_ctx, dim = x.shape\n\n        t = self.norm(x)\n        with record_function(\"qkv\"):\n            qkv = self.qkv(t)\n            qkv_parts = (\n                self.num_attention_heads * self.head_dim,\n                self.num_key_value_heads * self.head_dim,\n                self.num_key_value_heads * self.head_dim\n            )\n            q, k, v = torch.split(qkv, qkv_parts, dim=-1)\n            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()\n\n        q = 
q.view(batch_size, n_ctx, self.num_attention_heads, self.head_dim)\n        k = k.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)\n        v = v.view(batch_size, n_ctx, self.num_key_value_heads, self.head_dim)\n\n        if cache is not None:\n            offset = cache.offset.clone()\n            q, k = self.rope(q, k, offset=offset)\n            k, v = cache.extend(k, v)\n        else:\n            offset = torch.zeros((1,), dtype=torch.long, device=x.device)\n            q, k = self.rope(q, k, offset=offset)\n\n        q = q.view(\n            batch_size,\n            n_ctx,\n            self.num_attention_heads // self.num_key_value_heads,\n            self.num_key_value_heads,\n            self.head_dim,\n        )\n        with record_function(\"attn_kernel\"):\n            if n_ctx == 1:\n                t = attention_ref(\n                    q,\n                    k,\n                    v,\n                    self.sinks,\n                    self.sm_scale,\n                    self.sliding_window,\n                    offset,\n                )\n            else:\n                t = attention(\n                    q,\n                    k,\n                    v,\n                    self.sinks,\n                    self.sm_scale,\n                    self.sliding_window,\n                    offset,\n                )\n                if n_ctx < 64:\n                    t1 = attention_ref(\n                        q,\n                        k,\n                        v,\n                        self.sinks,\n                        self.sm_scale,\n                        self.sliding_window,\n                        offset,\n                    )\n                    torch.testing.assert_close(t, t1)\n                    t = t1\n\n        with record_function(\"c_proj\"):\n            t = self.out(t)\n        t = x + t\n        return t\n\n\nclass MLPBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: 
ModelConfig,\n        layer_idx: int = 0,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.num_experts = config.num_experts\n        self.experts_per_token = config.experts_per_token\n        self.swiglu_limit = config.swiglu_limit\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        self.gate = torch.nn.ParameterDict({\n            \"weight\": torch.nn.Parameter(\n                torch.empty(\n                    (config.hidden_size, config.num_experts),\n                    device=device,\n                    dtype=torch.bfloat16,\n                )\n            ),\n            \"bias\": torch.nn.Parameter(\n                torch.empty(\n                    (config.num_experts,),\n                    device=device,\n                    dtype=torch.bfloat16,\n                )\n            ),\n        })\n        self.mlp1_weight_tensor, self.mlp1_weight_mx = quantize_mx4(\n            torch.empty(\n                (\n                    config.num_experts,\n                    config.hidden_size,\n                    config.intermediate_size * 2,\n                ),\n                device=device,\n                dtype=torch.bfloat16,\n            ),\n        )\n        self.mlp1_weight = torch.nn.Parameter(self.mlp1_weight_tensor.storage.data, requires_grad=False)\n        self.mlp1_bias = torch.nn.Parameter(\n            torch.empty(\n                (config.num_experts, config.intermediate_size * 2),\n                device=device,\n                dtype=torch.bfloat16,\n            )\n        )\n        self.mlp2_weight_tensor, self.mlp2_weight_mx = quantize_mx4(\n            torch.empty(\n                (\n                    config.num_experts,\n                    config.intermediate_size,\n                    config.hidden_size,\n                ),\n                device=device,\n                dtype=torch.bfloat16,\n            ),\n        )\n        
self.mlp2_weight = torch.nn.Parameter(self.mlp2_weight_tensor.storage.data, requires_grad=False)\n        self.mlp2_bias = torch.nn.Parameter(\n            torch.empty(\n                (config.num_experts, config.hidden_size),\n                device=device,\n                dtype=torch.bfloat16,\n            )\n        )\n\n    @record_function(\"mlp\")\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        batch_size, n_ctx, dim = x.shape\n        t = self.norm(x)\n\n        t = t.view(batch_size * n_ctx, dim)\n        t = moe(\n            t,\n            self.gate[\"weight\"],\n            self.mlp1_weight_tensor, self.mlp1_weight_mx,\n            self.mlp2_weight_tensor, self.mlp2_weight_mx,\n            self.gate[\"bias\"].float(),\n            self.mlp1_bias.float(),\n            self.mlp2_bias.float(),\n            experts_per_token=self.experts_per_token,\n            num_experts=self.num_experts,\n            swiglu_limit=self.swiglu_limit,\n        )\n        t = t.view(batch_size, n_ctx, dim)\n\n        return x + t\n\n\nclass TransformerBlock(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        layer_idx: int,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.attn = AttentionBlock(config, layer_idx, device)\n        self.mlp = MLPBlock(config, layer_idx, device)\n\n    def forward(self, x: torch.Tensor, cache: Cache | None = None) -> torch.Tensor:\n        x = self.attn(x, cache=cache)\n        x = self.mlp(x)\n        return x\n\n\nclass Transformer(torch.nn.Module):\n    def __init__(\n        self,\n        config: ModelConfig,\n        device: torch.device | None = None,\n    ):\n        super().__init__()\n        self.config = config\n        self.embedding = torch.nn.Embedding(\n            config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16\n        )\n        self.block = 
torch.nn.ModuleList(\n            [\n                TransformerBlock(config, layer_idx, device)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = RMSNorm(config.hidden_size, device=device)\n        self.unembedding = torch.nn.Linear(\n            config.hidden_size,\n            config.vocab_size,\n            bias=False,\n            device=device,\n            dtype=torch.bfloat16,\n        )\n\n    def forward(self, x: torch.Tensor, caches: list[Cache] | None = None) -> torch.Tensor:\n        caches = caches or [None] * len(self.block)\n        with record_function(\"embedding\"):\n            x = self.embedding(x)\n        for block, cache in zip(self.block, caches):\n            with record_function(\"block\"):\n                x = block(x, cache=cache)\n        with record_function(\"norm_f\"):\n            x = self.norm(x)\n        with record_function(\"unembedding\"):\n            x = self.unembedding(x)\n        return x.float()\n\n    @staticmethod\n    def from_checkpoint(\n        path: str, config: ModelConfig | None = None, device: str | torch.device = \"cuda\",\n    ) -> \"Transformer\":\n        if not isinstance(device, torch.device):\n            device = torch.device(device)\n\n        if config is None:\n            config_path = os.path.join(path, \"config.json\")\n            with open(config_path, \"r\") as f:\n                json_config = json.load(f)\n                config = ModelConfig(**json_config)\n\n        model = Transformer(config=config, device=device)\n        model.eval()\n\n        checkpoint = Checkpoint(path, device)\n\n        for name, param in model.named_parameters():\n            torch.cuda.empty_cache()\n            loaded_tensor = checkpoint.get(name)\n\n            if \"mlp1\" in name:\n                if \"weight\" in name:\n                    loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())\n                    _, block_index, _, _ = 
name.split(\".\")\n                    model.block[int(block_index)].mlp.mlp1_weight_mx = scales\n                    param.data.copy_(loaded_tensor.storage.data)\n                else:\n                    param.data.copy_(loaded_tensor)\n\n            elif \"mlp2_weight\" in name:\n                loaded_tensor, scales = quantize_mx4(loaded_tensor.mT.contiguous())\n                _, block_index, _, _ = name.split(\".\")\n                model.block[int(block_index)].mlp.mlp2_weight_mx = scales\n                param.data.copy_(loaded_tensor.storage.data)\n\n            elif \"gate\" in name and loaded_tensor.ndim == 2:\n                loaded_tensor = loaded_tensor.mT.contiguous()\n                param.data.copy_(loaded_tensor)\n\n            else:\n                param.data.copy_(loaded_tensor)\n\n        # NOTE: Required to avoid OOM errors\n        torch.cuda.empty_cache()\n        return model\n\n\nclass TokenGenerator:\n    @torch.inference_mode()\n    def __init__(self, checkpoint: str, context: int, device: torch.device):\n        self.device = device\n        self.model = Transformer.from_checkpoint(checkpoint, device=self.device)\n        self.caches = [Cache(1, context, self.model.config.num_key_value_heads, device=self.device) for _ in range(len(self.model.block))]\n        self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)\n        # warmup\n        self.model(self.input_token[None, :], caches=self.caches)\n        # capture for sampling\n        self.graph = torch.cuda.CUDAGraph()\n        with torch.cuda.graph(self.graph):\n            self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]\n\n    @torch.inference_mode()\n    def generate(self,\n                 prompt_tokens: list[int],\n                 stop_tokens: list[int] | None = None,\n                 temperature: float = 1.0,\n                 max_tokens: int = 0,\n                 return_logprobs: bool = False):\n        stop_tokens = 
stop_tokens or []\n        for cache in self.caches:\n            cache.reset()\n        prompt_tokens = torch.as_tensor(prompt_tokens, dtype=torch.int32, device=self.device)\n        self.model(prompt_tokens[None, :-1], self.caches)\n        predicted_token = prompt_tokens[-1]\n        num_generated_tokens = 0\n        while max_tokens == 0 or num_generated_tokens < max_tokens:\n            self.input_token[0] = predicted_token\n            self.graph.replay()\n            if temperature == 0.0:\n                predicted_token = torch.argmax(self.logits[-1, :], dim=-1).item()\n            else:\n                probs = torch.softmax(self.logits * (1.0 / temperature), dim=-1)\n                predicted_token = torch.multinomial(probs[-1, :], num_samples=1).item()\n            num_generated_tokens += 1\n\n            if return_logprobs:\n                logprobs = torch.log_softmax(self.logits[-1, :], dim=-1)\n                selected_logprobs = logprobs[predicted_token].item()\n                yield predicted_token, selected_logprobs\n            else:\n                yield predicted_token\n\n            if predicted_token in stop_tokens:\n                break\n"
  },
  {
    "path": "gpt_oss/triton/moe.py",
    "content": "import torch\nfrom torch.profiler import record_function\n\nimport triton_kernels\nimport triton_kernels.swiglu\nfrom triton_kernels.numerics_details.mxfp import downcast_to_mxfp\nfrom triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation\nfrom triton_kernels.matmul_ogs import matmul_ogs\nfrom triton_kernels.numerics import InFlexData\nfrom triton_kernels.routing import routing\nfrom triton_kernels.tensor import convert_layout\nfrom triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout\nfrom triton_kernels.tensor import wrap_torch_tensor, FP4\n\n\ndef quantize_mx4(w):\n    w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)\n    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), HopperMXValueLayout, mx_axis=1)\n    w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)\n    return w, w_scale\n\n\ndef swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True):\n    if interleaved:\n        x_glu, x_linear = x[..., ::2], x[..., 1::2]\n    else:\n        x_glu, x_linear = torch.chunk(x, 2, dim=-1)\n    x_glu = x_glu.clamp(min=None, max=limit)\n    x_linear = x_linear.clamp(min=-limit, max=limit)\n    out_glu = x_glu * torch.sigmoid(alpha * x_glu)\n    return out_glu * (x_linear + 1)\n\n\ndef moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):\n    if x.numel() == 0:\n        return x\n\n    pc1 = PrecisionConfig(weight_scale=w1_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))\n    pc2 = PrecisionConfig(weight_scale=w2_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))\n    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=InFlexData()))\n\n    with record_function(\"wg\"):\n        logits = matmul_ogs(x, wg, bg, precision_config=pcg)\n    with record_function(\"routing\"):\n        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, 
simulated_ep=1)\n\n    if fused_act:\n        assert interleaved, \"Fused activation requires interleaved weights\"\n        with record_function(\"w1+swiglu\"):\n            act = FusedActivation(FnSpecs(\"swiglu\", triton_kernels.swiglu.swiglu_fn, (\"alpha\", \"limit\")), (1.702, swiglu_limit), 2)\n            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)\n    else:\n        with record_function(\"w1\"):\n            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)\n        with record_function(\"swiglu\"):\n            x = swiglu(x, limit=swiglu_limit, interleaved=interleaved)\n\n    with record_function(\"w2\"):\n        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2, gammas=rdata.gate_scal)\n    return x\n"
  },
  {
    "path": "gpt_oss/vllm/token_generator.py",
    "content": "from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt\n\n\nclass TokenGenerator:\n    def __init__(self, model_path: str, tensor_parallel_size: int = 1):\n        args = EngineArgs(\n            model=model_path,\n            tensor_parallel_size=tensor_parallel_size,\n        )\n        self.engine = LLMEngine.from_engine_args(args)\n        self.request_id = 0\n\n    def generate(self,\n                 prompt_tokens: list[int],\n                 stop_tokens: list[int] | None = None,\n                 temperature: float = 1.0,\n                 max_tokens: int = 0,\n                 return_logprobs: bool = False):\n        if max_tokens == 0:\n            max_tokens = None\n        request_id = str(self.request_id)\n        self.request_id += 1\n        sampling_params = SamplingParams(temperature=temperature,\n                                         max_tokens=max_tokens,\n                                         stop_token_ids=stop_tokens,\n                                         logprobs=0 if return_logprobs else None)\n        prompt = TokensPrompt(prompt_token_ids=prompt_tokens)\n        self.engine.add_request(request_id, prompt, sampling_params)\n        last_token_id = []\n        while self.engine.has_unfinished_requests():\n            step_outputs = self.engine.step()\n            output = step_outputs[0].outputs[0]\n            token_ids = output.token_ids\n            logprobs_list = output.logprobs if hasattr(output, \"logprobs\") else None\n            new_token_ids = token_ids[len(last_token_id):]\n            new_logprobs = logprobs_list[len(last_token_id):] if logprobs_list is not None else [None] * len(new_token_ids)\n            for token_id, logprobs in zip(new_token_ids, new_logprobs):\n                last_token_id.append(token_id)\n                if return_logprobs:\n                    logprob_val = None\n                    if logprobs is not None and token_id in logprobs:\n                        
logprob_val = logprobs[token_id].logprob\n                    yield (token_id, logprob_val)\n                else:\n                    yield token_id\n                if stop_tokens is not None and token_id in stop_tokens:\n                    break\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"gpt-oss\"\ndescription = \"A collection of reference inference implementations for gpt-oss by OpenAI\"\n\ndependencies = [\n  \"openai-harmony\",\n  \"tiktoken>=0.9.0\",\n  \"aiohttp>=3.12.14\",\n  \"chz>=0.3.0\",\n  \"docker>=7.1.0\",\n  \"fastapi>=0.116.1\",\n  \"html2text>=2025.4.15\",\n  \"lxml>=4.9.4\",\n  \"pydantic>=2.11.7\",\n  \"structlog>=25.4.0\",\n  \"tenacity>=9.1.2\",\n  \"uvicorn>=0.35.0\",\n  \"requests>=2.31.0\",\n  \"termcolor\",\n  \"jupyter-client>=8.6.3\",\n]\nreadme = \"README.md\"\nrequires-python = \">=3.12\"\nversion = \"0.0.9\"\n\n[project.optional-dependencies]\ntriton = [\"triton>=3.4\", \"safetensors>=0.5.3\", \"torch>=2.7.0\"]\ntorch = [\"safetensors>=0.5.3\", \"torch>=2.7.0\"]\nmetal = [\"numpy\", \"tqdm\", \"safetensors\", \"torch\"]\ntest = [\"pytest>=8.4.1\", \"httpx>=0.28.1\"]\neval = [\"pandas\", \"numpy\", \"openai\", \"jinja2\", \"tqdm\", \"blobfile\"]\n\n[build-system]\nrequires = [\"setuptools>=68\"]\nbuild-backend = \"gpt_oss_build_backend.backend\"\nbackend-path = [\"_build\"]\n\n[tool.setuptools.packages.find]\ninclude = [\"gpt_oss*\"]\n\n[tool.scikit-build]\ncmake.source-dir = \".\" # pick up the root CMakeLists.txt\ncmake.args = [\n  \"-DGPTOSS_BUILD_PYTHON=ON\",\n  \"-DCMAKE_BUILD_TYPE=Release\",\n  \"-DBUILD_SHARED_LIBS=OFF\",\n]\n[tool.scikit-build.wheel]\npackages = [\"gpt_oss\"] # copy the whole Python package tree\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import os\nimport sys\nimport pytest\nfrom typing import Generator\nfrom unittest.mock import MagicMock\nfrom fastapi.testclient import TestClient\n\nsys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))\n\nfrom openai_harmony import (\n    HarmonyEncodingName,\n    load_harmony_encoding,\n)\nfrom gpt_oss.responses_api.api_server import create_api_server\n\n\n@pytest.fixture(scope=\"session\")\ndef harmony_encoding():\n    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\n\n@pytest.fixture\ndef mock_infer_token(harmony_encoding):\n    fake_tokens = harmony_encoding.encode(\n        \"<|channel|>final<|message|>Test response<|return|>\",\n        allowed_special=\"all\"\n    )\n    token_queue = fake_tokens.copy()\n\n    def _mock_infer(tokens: list[int], temperature: float = 0.0, new_request: bool = False) -> int:\n        nonlocal token_queue\n        if len(token_queue) == 0:\n            token_queue = fake_tokens.copy()\n        return token_queue.pop(0)\n    return _mock_infer\n\n\n@pytest.fixture\ndef api_client(harmony_encoding, mock_infer_token) -> Generator[TestClient, None, None]:\n    app = create_api_server(\n        infer_next_token=mock_infer_token,\n        encoding=harmony_encoding\n    )\n    with TestClient(app) as client:\n        yield client\n\n\n@pytest.fixture\ndef sample_request_data():\n    return {\n        \"model\": \"gpt-oss-120b\",\n        \"input\": \"Hello, how can I help you today?\",\n        \"stream\": False,\n        \"reasoning_effort\": \"low\",\n        \"temperature\": 0.7,\n        \"tools\": []\n    }\n\n\n@pytest.fixture\ndef mock_browser_tool():\n    mock = MagicMock()\n    mock.search.return_value = [\"Result 1\", \"Result 2\"]\n    mock.open_page.return_value = \"Page content\"\n    mock.find_on_page.return_value = \"Found text\"\n    return mock\n\n\n@pytest.fixture\ndef mock_python_tool():\n    mock = MagicMock()\n    mock.execute.return_value = {\n        \"output\": \"print('Hello')\",\n        \"error\": None,\n        \"exit_code\": 0\n    }\n    return mock\n\n\n@pytest.fixture(autouse=True)\ndef reset_test_environment():\n    test_env_vars = ['OPENAI_API_KEY', 'GPT_OSS_MODEL_PATH']\n    original_values = {}\n\n    for var in test_env_vars:\n        if var in os.environ:\n            original_values[var] = os.environ[var]\n            del os.environ[var]\n\n    yield\n\n    for var, value in original_values.items():\n        os.environ[var] = value\n\n\n@pytest.fixture\ndef performance_timer():\n    import time\n\n    class Timer:\n        def __init__(self):\n            self.start_time = None\n            self.end_time = None\n\n        def start(self):\n            # perf_counter is monotonic, so elapsed time is unaffected by wall-clock adjustments\n            self.start_time = time.perf_counter()\n\n        def stop(self):\n            self.end_time = time.perf_counter()\n            return self.elapsed\n\n        @property\n        def elapsed(self):\n            if self.start_time is not None and self.end_time is not None:\n                return self.end_time - self.start_time\n            return None\n\n    return Timer()\n"
  },
  {
    "path": "tests/gpt_oss/tools/simple_browser/test_backend.py",
    "content": "import pytest\nfrom typing import Any\nfrom unittest import mock\nfrom aiohttp import ClientSession\n\nfrom gpt_oss.tools.simple_browser.backend import YouComBackend\n\nclass MockAiohttpResponse:\n    \"\"\"Mocks responses for get/post requests from async libraries.\"\"\"\n\n    def __init__(self, json: dict | list, status: int):\n        self._json = json\n        self.status = status\n\n    async def json(self):\n        return self._json\n\n    async def __aexit__(self, exc_type, exc, tb):\n        pass\n\n    async def __aenter__(self):\n        return self\n\ndef mock_os_environ_get(name: str, default: Any = \"test_api_key\"):\n    assert name in [\"YDC_API_KEY\"]\n    return default\n\ndef test_youcom_backend():\n    backend = YouComBackend(source=\"web\")\n    assert backend.source == \"web\"\n\n@pytest.mark.asyncio\n@mock.patch(\"aiohttp.ClientSession.get\")\nasync def test_youcom_backend_search(mock_session_get):\n    backend = YouComBackend(source=\"web\")\n    api_response = {\n        \"results\": {\n            \"web\": [\n                {\"title\": \"Web Result 1\", \"url\": \"https://www.example.com/web1\", \"snippets\": \"Web Result 1 snippets\"},\n                {\"title\": \"Web Result 2\", \"url\": \"https://www.example.com/web2\", \"snippets\": \"Web Result 2 snippets\"},\n            ],\n            \"news\": [\n                {\"title\": \"News Result 1\", \"url\": \"https://www.example.com/news1\", \"description\": \"News Result 1 description\"},\n                {\"title\": \"News Result 2\", \"url\": \"https://www.example.com/news2\", \"description\": \"News Result 2 description\"},\n            ],\n        }\n    }\n    with mock.patch(\"os.environ.get\", wraps=mock_os_environ_get):\n        mock_session_get.return_value = MockAiohttpResponse(api_response, 200)\n        async with ClientSession() as session:\n            result = await backend.search(query=\"test\", topn=10, session=session)\n        assert result.title == \"test\"\n        assert result.urls == {\"0\": \"https://www.example.com/web1\", \"1\": \"https://www.example.com/web2\", \"2\": \"https://www.example.com/news1\", \"3\": \"https://www.example.com/news2\"}\n\n@pytest.mark.asyncio\n@mock.patch(\"aiohttp.ClientSession.post\")\nasync def test_youcom_backend_fetch(mock_session_post):\n    backend = YouComBackend(source=\"web\")\n    api_response = [\n        {\"title\": \"Fetch Result 1\", \"url\": \"https://www.example.com/fetch1\", \"html\": \"<div>Fetch Result 1 text</div>\"},\n    ]\n    with mock.patch(\"os.environ.get\", wraps=mock_os_environ_get):\n        mock_session_post.return_value = MockAiohttpResponse(api_response, 200)\n        async with ClientSession() as session:\n            result = await backend.fetch(url=\"https://www.example.com/fetch1\", session=session)\n        assert result.title == \"Fetch Result 1\"\n        assert result.text == \"\\nURL: https://www.example.com/fetch1\\nFetch Result 1 text\"\n"
  },
  {
    "path": "tests/test_api_endpoints.py",
    "content": "import pytest\nimport json\nimport asyncio\nfrom fastapi import status\nfrom unittest.mock import patch, MagicMock, AsyncMock\n\n\nclass TestResponsesEndpoint:\n    \n    def test_basic_response_creation(self, api_client, sample_request_data):\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n        data = response.json()\n        assert \"id\" in data\n        assert data[\"object\"] == \"response\"\n        assert data[\"model\"] == sample_request_data[\"model\"]\n    \n    def test_response_with_high_reasoning(self, api_client, sample_request_data):\n        sample_request_data[\"reasoning_effort\"] = \"high\"\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n        data = response.json()\n        assert \"id\" in data\n        assert data[\"status\"] == \"completed\"\n    \n    def test_response_with_medium_reasoning(self, api_client, sample_request_data):\n        sample_request_data[\"reasoning_effort\"] = \"medium\"\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n        data = response.json()\n        assert \"id\" in data\n        assert data[\"status\"] == \"completed\"\n    \n    def test_response_with_invalid_model(self, api_client, sample_request_data):\n        sample_request_data[\"model\"] = \"invalid-model\"\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        # Should still accept but might handle differently\n        assert response.status_code == status.HTTP_200_OK\n    \n    def test_response_with_empty_input(self, api_client, sample_request_data):\n        sample_request_data[\"input\"] = \"\"\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == 
status.HTTP_200_OK\n    \n    def test_response_with_tools(self, api_client, sample_request_data):\n        sample_request_data[\"tools\"] = [\n            {\n                \"type\": \"browser_search\"\n            }\n        ]\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n    \n    def test_response_with_custom_temperature(self, api_client, sample_request_data):\n        for temp in [0.0, 0.5, 1.0, 1.5, 2.0]:\n            sample_request_data[\"temperature\"] = temp\n            response = api_client.post(\"/v1/responses\", json=sample_request_data)\n            assert response.status_code == status.HTTP_200_OK\n            data = response.json()\n            assert \"usage\" in data\n    \n    def test_streaming_response(self, api_client, sample_request_data):\n        sample_request_data[\"stream\"] = True\n        with api_client.stream(\"POST\", \"/v1/responses\", json=sample_request_data) as response:\n            assert response.status_code == status.HTTP_200_OK\n            # Verify we get SSE events\n            for line in response.iter_lines():\n                if line and line.startswith(\"data: \"):\n                    event_data = line[6:]  # Remove \"data: \" prefix\n                    if event_data != \"[DONE]\":\n                        json.loads(event_data)  # Should be valid JSON\n                        break\n\n\nclass TestResponsesWithSession:\n    \n    def test_response_with_session_id(self, api_client, sample_request_data):\n        session_id = \"test-session-123\"\n        sample_request_data[\"session_id\"] = session_id\n        \n        # First request\n        response1 = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response1.status_code == status.HTTP_200_OK\n        data1 = response1.json()\n        \n        # Second request with same session\n        sample_request_data[\"input\"] = \"Follow up 
question\"\n        response2 = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response2.status_code == status.HTTP_200_OK\n        data2 = response2.json()\n        \n        # Should have different response IDs\n        assert data1[\"id\"] != data2[\"id\"]\n    \n    def test_response_continuation(self, api_client, sample_request_data):\n        # Create initial response\n        response1 = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response1.status_code == status.HTTP_200_OK\n        data1 = response1.json()\n        response_id = data1[\"id\"]\n        \n        # Continue the response\n        continuation_request = {\n            \"model\": sample_request_data[\"model\"],\n            \"response_id\": response_id,\n            \"input\": \"Continue the previous thought\"\n        }\n        response2 = api_client.post(\"/v1/responses\", json=continuation_request)\n        assert response2.status_code == status.HTTP_200_OK\n\n\nclass TestErrorHandling:\n    \n    def test_missing_required_fields(self, api_client):\n        # Model field has default, so test with empty JSON\n        response = api_client.post(\"/v1/responses\", json={})\n        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY\n    \n    def test_invalid_reasoning_effort(self, api_client, sample_request_data):\n        sample_request_data[\"reasoning_effort\"] = \"invalid\"\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        # May handle gracefully or return error\n        assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY]\n    \n    def test_malformed_json(self, api_client):\n        response = api_client.post(\n            \"/v1/responses\",\n            data=\"not json\",\n            headers={\"Content-Type\": \"application/json\"}\n        )\n        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY\n    \n    
def test_extremely_long_input(self, api_client, sample_request_data):\n        # Test with very long input\n        sample_request_data[\"input\"] = \"x\" * 100000\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n\n\nclass TestToolIntegration:\n    \n    def test_browser_search_tool(self, api_client, sample_request_data):\n        sample_request_data[\"tools\"] = [\n            {\n                \"type\": \"browser_search\"\n            }\n        ]\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n    \n    def test_function_tool_integration(self, api_client, sample_request_data):\n        sample_request_data[\"tools\"] = [\n            {\n                \"type\": \"function\",\n                \"name\": \"test_function\",\n                \"parameters\": {\"type\": \"object\", \"properties\": {}},\n                \"description\": \"Test function\"\n            }\n        ]\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n    \n    def test_multiple_tools(self, api_client, sample_request_data):\n        sample_request_data[\"tools\"] = [\n            {\n                \"type\": \"browser_search\"\n            },\n            {\n                \"type\": \"function\",\n                \"name\": \"test_function\",\n                \"parameters\": {\"type\": \"object\", \"properties\": {}},\n                \"description\": \"Test function\"\n            }\n        ]\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n\n\nclass TestPerformance:\n    \n    def test_response_time_under_threshold(self, api_client, sample_request_data, performance_timer):\n        performance_timer.start()\n        response = 
api_client.post(\"/v1/responses\", json=sample_request_data)\n        elapsed = performance_timer.stop()\n        \n        assert response.status_code == status.HTTP_200_OK\n        # Response should be reasonably fast for mock inference\n        assert elapsed < 5.0  # 5 seconds threshold\n    \n    def test_multiple_sequential_requests(self, api_client, sample_request_data):\n        # Test multiple requests work correctly\n        for i in range(3):\n            data = sample_request_data.copy()\n            data[\"input\"] = f\"Request {i}\"\n            response = api_client.post(\"/v1/responses\", json=data)\n            assert response.status_code == status.HTTP_200_OK\n\n\nclass TestUsageTracking:\n    \n    def test_usage_object_structure(self, api_client, sample_request_data):\n        response = api_client.post(\"/v1/responses\", json=sample_request_data)\n        assert response.status_code == status.HTTP_200_OK\n        data = response.json()\n        \n        assert \"usage\" in data\n        usage = data[\"usage\"]\n        assert \"input_tokens\" in usage\n        assert \"output_tokens\" in usage\n        assert \"total_tokens\" in usage\n        # reasoning_tokens may not always be present\n        # assert \"reasoning_tokens\" in usage\n        \n        # Basic validation\n        assert usage[\"input_tokens\"] >= 0\n        assert usage[\"output_tokens\"] >= 0\n        assert usage[\"total_tokens\"] == usage[\"input_tokens\"] + usage[\"output_tokens\"]\n    \n    def test_usage_increases_with_longer_input(self, api_client, sample_request_data):\n        # Short input\n        response1 = api_client.post(\"/v1/responses\", json=sample_request_data)\n        usage1 = response1.json()[\"usage\"]\n        \n        # Longer input\n        sample_request_data[\"input\"] = sample_request_data[\"input\"] * 10\n        response2 = api_client.post(\"/v1/responses\", json=sample_request_data)\n        usage2 = response2.json()[\"usage\"]\n        \n    
    # Longer input should use more tokens\n        assert usage2[\"input_tokens\"] > usage1[\"input_tokens\"]"
  },
  {
    "path": "tests/test_responses_api.py",
    "content": "import time\n\nimport pytest\nfrom fastapi.testclient import TestClient\nfrom openai_harmony import (\n    HarmonyEncodingName,\n    load_harmony_encoding,\n)\n\nfrom gpt_oss.responses_api.api_server import create_api_server\n\nencoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n\nfake_tokens = encoding.encode(\n    \"<|channel|>final<|message|>Hey there<|return|>\", allowed_special=\"all\"\n)\n\ntoken_queue = fake_tokens.copy()\n\n\ndef stub_infer_next_token(\n    tokens: list[int], temperature: float = 0.0, new_request: bool = False\n) -> int:\n    # Replays the canned Harmony tokens in order, refilling the queue once it is exhausted.\n    global token_queue\n    next_tok = token_queue.pop(0)\n    if len(token_queue) == 0:\n        token_queue = fake_tokens.copy()\n    time.sleep(0.1)\n    return next_tok\n\n\n@pytest.fixture\ndef test_client():\n    return TestClient(\n        create_api_server(infer_next_token=stub_infer_next_token, encoding=encoding)\n    )\n\n\ndef test_create_response_returns_ok(test_client):\n    response = test_client.post(\n        \"/v1/responses\",\n        json={\n            \"model\": \"gpt-oss-120b\",\n            \"input\": \"Hello, world!\",\n        },\n    )\n    assert response.status_code == 200\n"
  },
  {
    "path": "tests-data/basic-event-stream.txt",
    "content": "event: response.created\ndata: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{\"id\":\"resp_687937d6852c819199d18805b160d13e0d28eb600b6e01a0\",\"object\":\"response\",\"created_at\":1752774614,\"status\":\"in_progress\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"o4-mini-2025-04-16\",\"output\":[],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"reasoning\":{\"effort\":\"low\",\"summary\":\"detailed\"},\"service_tier\":\"auto\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":null,\"user\":null,\"metadata\":{}}}\n\nevent: response.in_progress\ndata: {\"type\":\"response.in_progress\",\"sequence_number\":1,\"response\":{\"id\":\"resp_687937d6852c819199d18805b160d13e0d28eb600b6e01a0\",\"object\":\"response\",\"created_at\":1752774614,\"status\":\"in_progress\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"o4-mini-2025-04-16\",\"output\":[],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"reasoning\":{\"effort\":\"low\",\"summary\":\"detailed\"},\"service_tier\":\"auto\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":null,\"user\":null,\"metadata\":{}}}\n\nevent: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"sequence_number\":2,\"output_index\":0,\"item\":{\"id\":\"rs_687937d6ed748191b23a96ac7b1b9bb60d28eb600b6e01a0\",\"type\":\"reasoning\",\"summary\":[]}}\n\nevent: response.output_item.done\ndata: 
{\"type\":\"response.output_item.done\",\"sequence_number\":3,\"output_index\":0,\"item\":{\"id\":\"rs_687937d6ed748191b23a96ac7b1b9bb60d28eb600b6e01a0\",\"type\":\"reasoning\",\"summary\":[]}}\n\nevent: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"sequence_number\":4,\"output_index\":1,\"item\":{\"id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"type\":\"message\",\"status\":\"in_progress\",\"content\":[],\"role\":\"assistant\"}}\n\nevent: response.content_part.added\ndata: {\"type\":\"response.content_part.added\",\"sequence_number\":5,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"annotations\":[],\"logprobs\":[],\"text\":\"\"}}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":6,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\"Hello\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":7,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" there\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":8,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\"!\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":9,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" How\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":10,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" can\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":11,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" I\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":12,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" assist\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":13,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" you\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":14,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\" today\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":15,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"delta\":\"?\",\"logprobs\":[]}\n\nevent: response.output_text.done\ndata: {\"type\":\"response.output_text.done\",\"sequence_number\":16,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"text\":\"Hello there! 
How can I assist you today?\",\"logprobs\":[]}\n\nevent: response.content_part.done\ndata: {\"type\":\"response.content_part.done\",\"sequence_number\":17,\"item_id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"output_index\":1,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"annotations\":[],\"logprobs\":[],\"text\":\"Hello there! How can I assist you today?\"}}\n\nevent: response.output_item.done\ndata: {\"type\":\"response.output_item.done\",\"sequence_number\":18,\"output_index\":1,\"item\":{\"id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"type\":\"message\",\"status\":\"completed\",\"content\":[{\"type\":\"output_text\",\"annotations\":[],\"logprobs\":[],\"text\":\"Hello there! How can I assist you today?\"}],\"role\":\"assistant\"}}\n\nevent: response.completed\ndata: {\"type\":\"response.completed\",\"sequence_number\":19,\"response\":{\"id\":\"resp_687937d6852c819199d18805b160d13e0d28eb600b6e01a0\",\"object\":\"response\",\"created_at\":1752774614,\"status\":\"completed\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"o4-mini-2025-04-16\",\"output\":[{\"id\":\"rs_687937d6ed748191b23a96ac7b1b9bb60d28eb600b6e01a0\",\"type\":\"reasoning\",\"summary\":[]},{\"id\":\"msg_687937d95cc08191aa918aa59c886a270d28eb600b6e01a0\",\"type\":\"message\",\"status\":\"completed\",\"content\":[{\"type\":\"output_text\",\"annotations\":[],\"logprobs\":[],\"text\":\"Hello there! 
How can I assist you today?\"}],\"role\":\"assistant\"}],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"reasoning\":{\"effort\":\"low\",\"summary\":\"detailed\"},\"service_tier\":\"default\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":{\"input_tokens\":18,\"input_tokens_details\":{\"cached_tokens\":0},\"output_tokens\":16,\"output_tokens_details\":{\"reasoning_tokens\":0},\"total_tokens\":34},\"user\":null,\"metadata\":{}}}\n\n"
  },
  {
    "path": "tests-data/web-search-event-stream.txt",
    "content": "event: response.created\ndata: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{\"id\":\"resp_688867b6fb90819e92212445bb8289840b8311511b435264\",\"object\":\"response\",\"created_at\":1753769911,\"status\":\"in_progress\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"gpt-4.1-2025-04-14\",\"output\":[],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"prompt_cache_key\":null,\"reasoning\":{\"effort\":null,\"summary\":null},\"safety_identifier\":null,\"service_tier\":\"auto\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[{\"type\":\"web_search_preview\",\"search_context_size\":\"medium\",\"user_location\":{\"type\":\"approximate\",\"city\":null,\"country\":\"US\",\"region\":null,\"timezone\":null}}],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":null,\"user\":null,\"metadata\":{}}}\n\nevent: response.in_progress\ndata: {\"type\":\"response.in_progress\",\"sequence_number\":1,\"response\":{\"id\":\"resp_688867b6fb90819e92212445bb8289840b8311511b435264\",\"object\":\"response\",\"created_at\":1753769911,\"status\":\"in_progress\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful 
assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"gpt-4.1-2025-04-14\",\"output\":[],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"prompt_cache_key\":null,\"reasoning\":{\"effort\":null,\"summary\":null},\"safety_identifier\":null,\"service_tier\":\"auto\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[{\"type\":\"web_search_preview\",\"search_context_size\":\"medium\",\"user_location\":{\"type\":\"approximate\",\"city\":null,\"country\":\"US\",\"region\":null,\"timezone\":null}}],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":null,\"user\":null,\"metadata\":{}}}\n\nevent: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"sequence_number\":2,\"output_index\":0,\"item\":{\"id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\",\"type\":\"web_search_call\",\"status\":\"in_progress\",\"action\":{\"type\":\"search\"}}}\n\nevent: response.web_search_call.in_progress\ndata: {\"type\":\"response.web_search_call.in_progress\",\"sequence_number\":3,\"output_index\":0,\"item_id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\"}\n\nevent: response.web_search_call.searching\ndata: {\"type\":\"response.web_search_call.searching\",\"sequence_number\":4,\"output_index\":0,\"item_id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\"}\n\nevent: response.web_search_call.completed\ndata: {\"type\":\"response.web_search_call.completed\",\"sequence_number\":5,\"output_index\":0,\"item_id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\"}\n\nevent: response.output_item.done\ndata: {\"type\":\"response.output_item.done\",\"sequence_number\":6,\"output_index\":0,\"item\":{\"id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\",\"type\":\"web_search_call\",\"status\":\"completed\",\"action\":{\"type\":\"search\",\"query\":\"positive news stories today\"}}}\n\nevent: response.output_item.added\ndata: 
{\"type\":\"response.output_item.added\",\"sequence_number\":7,\"output_index\":1,\"item\":{\"id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"type\":\"message\",\"status\":\"in_progress\",\"content\":[],\"role\":\"assistant\"}}\n\nevent: response.content_part.added\ndata: {\"type\":\"response.content_part.added\",\"sequence_number\":8,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"annotations\":[],\"logprobs\":[],\"text\":\"\"}}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":9,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"As\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":10,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" of\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":11,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" July\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":12,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":13,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"29\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":14,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":15,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":16,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"202\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":17,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"5\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":18,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":19,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" one\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":20,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" uplifting\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":21,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" news\",\"logprobs\":[]}\n\nevent: 
response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":22,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" story\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":23,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" is\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":24,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":25,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" re\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":26,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"int\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":27,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"roduction\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":28,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" of\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":29,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
giant\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":30,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" river\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":31,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" ot\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":32,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"ters\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":33,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":34,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Argentina\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":35,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"'s\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":36,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Iber\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":37,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"á\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":38,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" wetlands\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":39,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":40,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":41,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"After\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":42,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" an\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":43,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" absence\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":44,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" of\",\"logprobs\":[]}\n\nevent: 
response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":45,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" over\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":46,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":47,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"40\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":48,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" years\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":49,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" due\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":50,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":51,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" habitat\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":52,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
loss\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":53,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" and\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":54,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" illegal\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":55,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" hunting\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":56,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":57,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" a\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":58,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" family\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":59,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" of\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":60,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" four\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":61,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" ot\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":62,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"ters\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":63,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":64,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" including\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":65,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" two\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":66,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" pups\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":67,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" born\",\"logprobs\":[]}\n\nevent: 
response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":68,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" in\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":69,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" captivity\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":70,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":71,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" has\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":72,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" been\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":73,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" released\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":74,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" into\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":75,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
their\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":76,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" original\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":77,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" habitat\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":78,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":79,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":80,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"This\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":81,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" marks\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":82,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" a\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":83,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" significant\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":84,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" step\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":85,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" in\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":86,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" conservation\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":87,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" efforts\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":88,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":89,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" restore\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":90,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":91,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" species\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":92,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" in\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":93,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":94,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" region\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":95,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":96,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":97,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"([conservationoptimism.org](https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai))\",\"logprobs\":[]}\n\nevent: response.output_text.annotation.added\ndata: 
{\"type\":\"response.output_text.annotation.added\",\"sequence_number\":98,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"annotation_index\":0,\"annotation\":{\"type\":\"url_citation\",\"end_index\":529,\"start_index\":398,\"title\":\"7 stories of optimism this week (08.07.25-14.07.25) - Conservation Optimism\",\"url\":\"https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai\"}}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":99,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"\\n\\nAdditionally\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":100,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":101,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":102,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" River\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":103,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Seine\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":104,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" in\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":105,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Paris\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":106,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" has\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":107,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" reopened\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":108,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":109,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" swimmers\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":110,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" for\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":111,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: 
response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":112,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" first\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":113,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" time\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":114,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" since\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":115,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":116,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"192\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":117,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"3\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":118,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":119,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":120,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"Following\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":121,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" a\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":122,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" $\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":123,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"1\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":124,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":125,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"6\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":126,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" billion\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":127,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" cleanup\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":128,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":129,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" three\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":130,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" designated\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":131,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" areas\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":132,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" near\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":133,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" landmarks\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":134,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
like\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":135,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":136,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Eiffel\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":137,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Tower\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":138,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" and\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":139,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Notre\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":140,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Dame\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":141,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Cathedral\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":142,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" now\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":143,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" allow\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":144,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" public\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":145,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" swimming\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":146,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\",\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":147,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" providing\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":148,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" Par\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":149,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"isi\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":150,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"ans\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":151,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" and\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":152,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" visitors\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":153,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" a\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":154,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" unique\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":155,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" way\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":156,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: 
response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":157,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" enjoy\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":158,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" the\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":159,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" city\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":160,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":161,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":162,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"([onlygoodnewsdaily.com](https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai))\",\"logprobs\":[]}\n\nevent: response.output_text.annotation.added\ndata: {\"type\":\"response.output_text.annotation.added\",\"sequence_number\":163,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"annotation_index\":1,\"annotation\":{\"type\":\"url_citation\",\"end_index\":947,\"start_index\":837,\"title\":\"Today's Good News | OGN 
Daily\",\"url\":\"https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai\"}}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":164,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\"\\n\\nThese\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":165,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" stories\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":166,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" highlight\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":167,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" ongoing\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":168,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" global\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":169,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" efforts\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":170,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" to\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: 
{\"type\":\"response.output_text.delta\",\"sequence_number\":171,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" restore\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":172,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" natural\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":173,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" habitats\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":174,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" and\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":175,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" enhance\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":176,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" urban\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":177,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" environments\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":178,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" 
for\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":179,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" public\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":180,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" enjoyment\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":181,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\".\",\"logprobs\":[]}\n\nevent: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"sequence_number\":182,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"delta\":\" \",\"logprobs\":[]}\n\nevent: response.output_text.done\ndata: {\"type\":\"response.output_text.done\",\"sequence_number\":183,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"text\":\"As of July 29, 2025, one uplifting news story is the reintroduction of giant river otters to Argentina's Iberá wetlands. After an absence of over 40 years due to habitat loss and illegal hunting, a family of four otters, including two pups born in captivity, has been released into their original habitat. This marks a significant step in conservation efforts to restore the species in the region. ([conservationoptimism.org](https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai))\\n\\nAdditionally, the River Seine in Paris has reopened to swimmers for the first time since 1923. 
Following a $1.6 billion cleanup, three designated areas near landmarks like the Eiffel Tower and Notre Dame Cathedral now allow public swimming, providing Parisians and visitors a unique way to enjoy the city. ([onlygoodnewsdaily.com](https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai))\\n\\nThese stories highlight ongoing global efforts to restore natural habitats and enhance urban environments for public enjoyment. \",\"logprobs\":[]}\n\nevent: response.content_part.done\ndata: {\"type\":\"response.content_part.done\",\"sequence_number\":184,\"item_id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"output_index\":1,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"annotations\":[{\"type\":\"url_citation\",\"end_index\":529,\"start_index\":398,\"title\":\"7 stories of optimism this week (08.07.25-14.07.25) - Conservation Optimism\",\"url\":\"https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai\"},{\"type\":\"url_citation\",\"end_index\":947,\"start_index\":837,\"title\":\"Today's Good News | OGN Daily\",\"url\":\"https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai\"}],\"logprobs\":[],\"text\":\"As of July 29, 2025, one uplifting news story is the reintroduction of giant river otters to Argentina's Iberá wetlands. After an absence of over 40 years due to habitat loss and illegal hunting, a family of four otters, including two pups born in captivity, has been released into their original habitat. This marks a significant step in conservation efforts to restore the species in the region. ([conservationoptimism.org](https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai))\\n\\nAdditionally, the River Seine in Paris has reopened to swimmers for the first time since 1923. 
Following a $1.6 billion cleanup, three designated areas near landmarks like the Eiffel Tower and Notre Dame Cathedral now allow public swimming, providing Parisians and visitors a unique way to enjoy the city. ([onlygoodnewsdaily.com](https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai))\\n\\nThese stories highlight ongoing global efforts to restore natural habitats and enhance urban environments for public enjoyment. \"}}\n\nevent: response.output_item.done\ndata: {\"type\":\"response.output_item.done\",\"sequence_number\":185,\"output_index\":1,\"item\":{\"id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"type\":\"message\",\"status\":\"completed\",\"content\":[{\"type\":\"output_text\",\"annotations\":[{\"type\":\"url_citation\",\"end_index\":529,\"start_index\":398,\"title\":\"7 stories of optimism this week (08.07.25-14.07.25) - Conservation Optimism\",\"url\":\"https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai\"},{\"type\":\"url_citation\",\"end_index\":947,\"start_index\":837,\"title\":\"Today's Good News | OGN Daily\",\"url\":\"https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai\"}],\"logprobs\":[],\"text\":\"As of July 29, 2025, one uplifting news story is the reintroduction of giant river otters to Argentina's Iberá wetlands. After an absence of over 40 years due to habitat loss and illegal hunting, a family of four otters, including two pups born in captivity, has been released into their original habitat. This marks a significant step in conservation efforts to restore the species in the region. ([conservationoptimism.org](https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai))\\n\\nAdditionally, the River Seine in Paris has reopened to swimmers for the first time since 1923. 
Following a $1.6 billion cleanup, three designated areas near landmarks like the Eiffel Tower and Notre Dame Cathedral now allow public swimming, providing Parisians and visitors a unique way to enjoy the city. ([onlygoodnewsdaily.com](https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai))\\n\\nThese stories highlight ongoing global efforts to restore natural habitats and enhance urban environments for public enjoyment. \"}],\"role\":\"assistant\"}}\n\nevent: response.completed\ndata: {\"type\":\"response.completed\",\"sequence_number\":186,\"response\":{\"id\":\"resp_688867b6fb90819e92212445bb8289840b8311511b435264\",\"object\":\"response\",\"created_at\":1753769911,\"status\":\"completed\",\"background\":false,\"error\":null,\"incomplete_details\":null,\"instructions\":\"You are a helpful assistant.\",\"max_output_tokens\":null,\"max_tool_calls\":null,\"model\":\"gpt-4.1-2025-04-14\",\"output\":[{\"id\":\"ws_688867b77b7c819ebd9791fd981b6b560b8311511b435264\",\"type\":\"web_search_call\",\"status\":\"completed\",\"action\":{\"type\":\"search\",\"query\":\"positive news stories today\"}},{\"id\":\"msg_688867b99c54819e8db837fcf08da9040b8311511b435264\",\"type\":\"message\",\"status\":\"completed\",\"content\":[{\"type\":\"output_text\",\"annotations\":[{\"type\":\"url_citation\",\"end_index\":529,\"start_index\":398,\"title\":\"7 stories of optimism this week (08.07.25-14.07.25) - Conservation Optimism\",\"url\":\"https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai\"},{\"type\":\"url_citation\",\"end_index\":947,\"start_index\":837,\"title\":\"Today's Good News | OGN Daily\",\"url\":\"https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai\"}],\"logprobs\":[],\"text\":\"As of July 29, 2025, one uplifting news story is the reintroduction of giant river otters to Argentina's Iberá wetlands. 
After an absence of over 40 years due to habitat loss and illegal hunting, a family of four otters, including two pups born in captivity, has been released into their original habitat. This marks a significant step in conservation efforts to restore the species in the region. ([conservationoptimism.org](https://conservationoptimism.org/7-stories-of-optimism-this-week-08-07-25-14-07-25/?utm_source=openai))\\n\\nAdditionally, the River Seine in Paris has reopened to swimmers for the first time since 1923. Following a $1.6 billion cleanup, three designated areas near landmarks like the Eiffel Tower and Notre Dame Cathedral now allow public swimming, providing Parisians and visitors a unique way to enjoy the city. ([onlygoodnewsdaily.com](https://www.onlygoodnewsdaily.com/post/just-good-news-7-july-2025?utm_source=openai))\\n\\nThese stories highlight ongoing global efforts to restore natural habitats and enhance urban environments for public enjoyment. \"}],\"role\":\"assistant\"}],\"parallel_tool_calls\":true,\"previous_response_id\":null,\"prompt_cache_key\":null,\"reasoning\":{\"effort\":null,\"summary\":null},\"safety_identifier\":null,\"service_tier\":\"default\",\"store\":true,\"temperature\":1.0,\"text\":{\"format\":{\"type\":\"text\"}},\"tool_choice\":\"auto\",\"tools\":[{\"type\":\"web_search_preview\",\"search_context_size\":\"medium\",\"user_location\":{\"type\":\"approximate\",\"city\":null,\"country\":\"US\",\"region\":null,\"timezone\":null}}],\"top_logprobs\":0,\"top_p\":1.0,\"truncation\":\"disabled\",\"usage\":{\"input_tokens\":320,\"input_tokens_details\":{\"cached_tokens\":0},\"output_tokens\":256,\"output_tokens_details\":{\"reasoning_tokens\":0},\"total_tokens\":576},\"user\":null,\"metadata\":{}}}"
  }
]